<a href="https://colab.research.google.com/github/Dhyanesh-Panchal/amazon_ml_2024_deepinsight/blob/master/solution/Multi_input_endoder_decoder_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

graph TD
    A[Input Image] --> B[Vision Encoder]
    C[Text Query] --> D[Text Encoder]
    B --> E[Cross-Attention Transformer]
    D --> E
    E --> F[Decoder]
    F --> G[Output Text]

In [70]:
import torch
import torch.nn as nn
from transformers import ViTModel

import torchvision.models as models
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence

from PIL import Image
import requests
from io import BytesIO

import pandas as pd
import numpy as np
import re
from typing import List, Tuple
from tqdm import tqdm

from sklearn.model_selection import train_test_split

## Model & Supporting Classes

In [71]:
class OutputTokenizer:
    """Tokenizer for target strings of the form "<number> <unit>".

    Token id layout:
        0-3   special tokens (PAD, SOS, EOS, UNK)
        4-13  the digits 0-9
        14    the decimal point
        15+   units, in the order they appear in ``unit_list``
    """

    # Offsets of the non-special token ranges (see class docstring)
    _DIGIT_OFFSET = 4
    _DECIMAL_POINT_TOKEN = 14
    _UNIT_OFFSET = 15

    # A leading number followed by a unit made of one or more words
    _VALUE_PATTERN = re.compile(r'(\d+\.?\d*)\s*(\w+(?:\s+\w+)*)')

    def __init__(self, unit_list: List[str]):
        """Build the unit <-> id lookup tables.

        Args:
            unit_list: Known unit names; a unit's position is its unit id.
        """
        self.unit_list = unit_list
        self.unit_to_id = {unit: i for i, unit in enumerate(unit_list)}
        self.id_to_unit = {i: unit for i, unit in enumerate(unit_list)}
        self.num_units = len(unit_list)

        # Special tokens
        self.PAD_token = 0
        self.SOS_token = 1
        self.EOS_token = 2
        self.UNK_token = 3  # for Unknown tokens

        # Vocabulary size: special tokens + digits + decimal point + units
        self.vocab_size = 4 + 10 + 1 + self.num_units

    def _unit_token(self, unit_text: str) -> int:
        """Return the token id for ``unit_text``, or ``UNK_token`` if unknown.

        The whole (whitespace-normalised) phrase is tried first so multi-word
        units such as "fluid ounce" resolve; the first word alone is the
        fallback, which matches the previous single-word behaviour.
        """
        normalized = " ".join(unit_text.split())
        first_word = normalized.split(" ", 1)[0]
        for candidate in (normalized, first_word):
            unit_id = self.unit_to_id.get(candidate)
            if unit_id is not None:
                return unit_id + self._UNIT_OFFSET
        # Unknown units must NOT be offset: UNK_token + 15 would alias the
        # real unit stored at index 3.
        return self.UNK_token

    def _fit(self, tokens: List[int], max_length: int) -> List[int]:
        """Pad with PAD, or truncate (keeping a final EOS), to ``max_length``."""
        if len(tokens) < max_length:
            return tokens + [self.PAD_token] * (max_length - len(tokens))
        return tokens[:max_length - 1] + [self.EOS_token]

    def tokenize(self, text: str, max_length: int = 50) -> List[int]:
        """Encode ``text`` as a list of ``max_length`` token ids.

        Args:
            text: A string such as "500.0 gram".
            max_length: Length of the returned sequence.

        Returns:
            ``[SOS, <digit / decimal-point tokens>, <unit token>, EOS]``
            padded with PAD or truncated to ``max_length``. Text that does not
            start with a number followed by a unit becomes ``[SOS, UNK, EOS]``
            plus padding. An unrecognised unit is encoded as ``UNK_token``.
        """
        match = self._VALUE_PATTERN.match(text.strip())
        if not match:
            return self._fit(
                [self.SOS_token, self.UNK_token, self.EOS_token], max_length
            )

        number, unit_text = match.groups()

        # One token per character of the number, in reading order
        number_tokens = [
            self._DECIMAL_POINT_TOKEN if char == '.' else int(char) + self._DIGIT_OFFSET
            for char in number
        ]

        tokens = (
            [self.SOS_token]
            + number_tokens
            + [self._unit_token(unit_text)]
            + [self.EOS_token]
        )
        return self._fit(tokens, max_length)

    def detokenize(self, tokens: List[int]) -> str:
        """Decode token ids back into a "<number> <unit>" string.

        SOS and PAD tokens are skipped and decoding stops at the first EOS.
        An UNK token, or a unit id that is not in the vocabulary, yields the
        unit text 'UNK'.
        """
        number = ''
        unit = ''
        for token in tokens:
            if token == self.SOS_token or token == self.PAD_token:
                continue
            elif token == self.EOS_token:
                break
            elif token == self.UNK_token:
                unit = 'UNK'
            elif 4 <= token <= 13:
                number += str(token - self._DIGIT_OFFSET)
            elif token == self._DECIMAL_POINT_TOKEN:
                number += '.'
            elif token >= self._UNIT_OFFSET:
                unit = self.id_to_unit.get(token - self._UNIT_OFFSET, 'UNK')

        return f"{number} {unit}"

    def decode(self, token_ids: List[int]) -> str:
        """Alias for :meth:`detokenize`."""
        return self.detokenize(token_ids)


In [72]:
class EntityPredictorVisionBasedModel_v2(nn.Module):
    """Predict a numeric value and a unit class for a product image.

    Image features from a ResNet-50 and a ViT are concatenated and projected,
    fused with learned embeddings of the entity-name and group ids, and fed to
    two heads: a small MLP that outputs one number and a linear unit classifier.
    """

    def __init__(self, num_entity_names, num_group_ids, num_units):
        """Create the backbones, embeddings, fusion layer and output heads.

        Args:
            num_entity_names: Number of rows in the entity-name embedding table.
            num_group_ids: Number of rows in the group-id embedding table.
            num_units: Number of unit classes the classifier head scores.
        """
        super().__init__()

        # Layer widths used below
        cnn_dim, vit_dim = 2048, 768
        image_dim, embed_dim, fused_dim = 1024, 256, 512

        # CNN backbone: the final fully connected layer is swapped for an
        # identity so the network returns features instead of class scores
        self.cnn = models.resnet50(pretrained=True)
        self.cnn.fc = nn.Identity()

        # Transformer backbone
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

        # Projects the concatenated CNN + ViT features
        self.feature_combiner = nn.Linear(cnn_dim + vit_dim, image_dim)

        # Learned lookups for the two categorical inputs
        self.entity_embedding = nn.Embedding(num_entity_names, embed_dim)
        self.group_embedding = nn.Embedding(num_group_ids, embed_dim)

        # Mixes image features with both embeddings
        self.fusion_layer = nn.Linear(image_dim + 2 * embed_dim, fused_dim)

        # Regression head: fused features -> a single number
        regression_stack = [
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        self.numerical_layer = nn.Sequential(*regression_stack)

        # Classification head: fused features -> one logit per unit
        self.unit_classifier = nn.Linear(fused_dim, num_units)

    def forward(self, images, entity_names, group_ids):
        """Run the model on one batch.

        Args:
            images: Image batch passed to both backbones
                (presumably (batch, 3, 224, 224) -- TODO confirm with caller).
            entity_names: Entity-name indices
                (assumed shape (batch,) so the dim=1 concatenation works).
            group_ids: Group-id indices (same shape assumption).

        Returns:
            Tuple ``(numerical_value, unit_logits)`` from the regression head
            and the unit classifier respectively.
        """
        # Image branch: ResNet features plus the ViT's first ([CLS]) token
        resnet_out = self.cnn(images)
        cls_embedding = self.vit(images).last_hidden_state[:, 0, :]
        visual = self.feature_combiner(
            torch.cat([resnet_out, cls_embedding], dim=1)
        )

        # Categorical branch: entity and group embeddings side by side
        entity_vec = self.entity_embedding(entity_names)
        group_vec = self.group_embedding(group_ids)
        context = torch.cat([entity_vec, group_vec], dim=1)

        # Fuse both branches, then apply the two heads
        fused = self.fusion_layer(torch.cat([visual, context], dim=1))
        predicted_value = self.numerical_layer(fused)
        predicted_unit_logits = self.unit_classifier(fused)
        return predicted_value, predicted_unit_logits

In [73]:
class ProductImageDataset(Dataset):
    """Dataset yielding (image, entity index, group index, value, unit index).

    Each row of ``df`` must provide the columns ``image_link``,
    ``entity_name_index``, ``group_id_index`` and ``entity_value``.
    Images are downloaded over HTTP on every access.
    """

    def __init__(self, df, unit_to_index, transform=None):
        """Store the frame, the unit lookup and the image transform.

        Args:
            df: DataFrame holding one sample per row.
            unit_to_index: Mapping from unit string to class index.
            transform: Callable applied to the PIL image. Defaults to a
                224x224 resize, tensor conversion and normalisation.
        """
        self.df = df
        self.unit_to_index = unit_to_index
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Download and transform the image at position ``idx`` and build its targets.

        Raises:
            requests.HTTPError: If the image request returns an error status.
        """
        row = self.df.iloc[idx]

        # Load image (network round-trip per sample; fails loudly on HTTP errors)
        response = requests.get(row['image_link'], timeout=10)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert('RGB')
        img = self.transform(img)

        # Get other data
        entity_name = torch.tensor(row['entity_name_index'], dtype=torch.long)
        group_id = torch.tensor(row['group_id_index'], dtype=torch.long)

        # Parse the entity_value into numerical value and unit
        value, unit = self.parse_entity_value(row['entity_value'])

        return img, entity_name, group_id, value, unit

    def parse_entity_value(self, entity_value):
        """Split "<number> <unit...>" into a float tensor and a unit-index tensor.

        Everything after the number is treated as the unit, so multi-word
        units such as "fluid ounce" are looked up as a whole. Units missing
        from ``unit_to_index`` map to index 0.

        Raises:
            IndexError: If ``entity_value`` is empty or whitespace only.
            ValueError: If the first field is not a valid float.
        """
        parts = entity_value.split()
        value = float(parts[0])
        # Join ALL remaining words: using only parts[1] would truncate
        # "fluid ounce" to "fluid" and silently fall back to index 0.
        unit_str = " ".join(parts[1:])

        # Convert unit string to index
        unit_index = self.unit_to_index.get(unit_str, 0)  # Use 0 as default if unit not found

        return torch.tensor(value, dtype=torch.float), torch.tensor(unit_index, dtype=torch.long)

## Training

In [74]:
# Seed PyTorch's default random generator (fixed seed 42) for reproducibility.
# Only torch is seeded here; NumPy and Python's `random` are left untouched.
torch.manual_seed(42)

<torch._C.Generator at 0x7af954b28c70>

In [75]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [76]:
# Load the filtered training CSV directly from its raw GitHub URL (needs network access)
df = pd.read_csv('https://raw.githubusercontent.com/Dhyanesh-Panchal/amazon_ml_2024_deepinsight/refs/heads/master/student_resource%203/dataset/filtered_train.csv')
# Show the first two rows as a quick look at the columns
df.head(2)

Unnamed: 0,index,image_link,group_id,entity_name,entity_value
0,0,https://m.media-amazon.com/images/I/61I9XdN6OF...,748919,item_weight,500.0 gram
1,1,https://m.media-amazon.com/images/I/71gSRbyXmo...,916768,item_volume,1.0 cup


In [77]:
# Discard rows with a missing entity_value; that column is the training target
df = df.dropna(subset=['entity_value'])

In [78]:
# Print column dtypes and non-null counts for the cleaned frame
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 260311 entries, 0 to 260310
Data columns (total 5 columns):
 #   Column        Non-Null Count   Dtype 
---  ------        --------------   ----- 
 0   index         260311 non-null  int64 
 1   image_link    260311 non-null  object
 2   group_id      260311 non-null  int64 
 3   entity_name   260311 non-null  object
 4   entity_value  260311 non-null  object
dtypes: int64(2), object(3)
memory usage: 9.9+ MB


In [79]:
# Give every distinct entity name and group id a contiguous integer index,
# numbered in order of first appearance, then store the index per row.
entity_to_index = {
    name: position
    for position, name in enumerate(pd.unique(df['entity_name']))
}
group_to_index = {
    group: position
    for position, group in enumerate(pd.unique(df['group_id']))
}

df['entity_name_index'] = df['entity_name'].map(entity_to_index)
df['group_id_index'] = df['group_id'].map(group_to_index)

In [80]:
# The unit is everything after the first space of entity_value
# (empty string when there is no space); index the distinct units
# in order of first appearance.
df['entity_value_unit'] = df['entity_value'].apply(
    lambda raw_value: raw_value.partition(" ")[2]
)
unit_to_index = {
    unit_name: position
    for position, unit_name in enumerate(pd.unique(df['entity_value_unit']))
}

In [81]:
# Split the data: hold out 20% of the rows for validation.
# The fixed random_state keeps the split identical between runs.
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

In [82]:
# Create datasets. Both splits share the same unit_to_index mapping so unit
# class ids agree between training and validation. No transform is passed,
# so each dataset falls back to its default image transform.
train_dataset = ProductImageDataset(train_df, unit_to_index)
val_dataset = ProductImageDataset(val_df, unit_to_index)

In [83]:
def custom_collate(batch):
    """Collate a list of dataset samples into batched tensors.

    Two sample layouts are accepted:

    * ``(image, entity_name, group_id, value, unit)`` -- the layout produced
      by ``ProductImageDataset``. All five fields are stacked along a new
      leading batch dimension.
    * ``(image, entity_name, group_id, target_sequence)`` -- the original
      sequence-target layout. Targets are right-padded with 0 to the length
      of the longest sequence in the batch.

    Args:
        batch: Non-empty list of samples that all use the same layout.

    Returns:
        A tuple with one batched tensor per sample field.

    Raises:
        ValueError: If the samples have neither four nor five fields
            (this includes an empty batch).
    """
    # Transpose "list of samples" into "one tuple per field"
    columns = tuple(zip(*batch))

    if len(columns) == 5:
        # Every field is a fixed-size tensor, so plain stacking is enough
        return tuple(torch.stack(column, 0) for column in columns)

    images, entity_names, group_ids, targets = columns

    # Stack images, entity_names, and group_ids (assuming they're already tensors of uniform size)
    images = torch.stack(images, 0)
    entity_names = torch.stack(entity_names, 0)
    group_ids = torch.stack(group_ids, 0)

    # Pad the variable-length target sequences to a common length
    targets = pad_sequence(targets, batch_first=True, padding_value=0)

    return images, entity_names, group_ids, targets

In [84]:
# Create dataloaders: batches of 32 loaded by 4 worker processes each.
# Training batches are shuffled; validation order is kept fixed.
# No collate_fn is passed, so the DataLoader's default collation is used.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)




In [85]:
# Initialise the Model. The embedding table sizes and the width of the unit
# classifier are taken from the sizes of the index mappings built earlier.
model = EntityPredictorVisionBasedModel_v2(
    num_entity_names=len(entity_to_index),
    num_group_ids=len(group_to_index),
    num_units=len(unit_to_index)
)



In [86]:
# Pick the compute device again: GPU when available, otherwise CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [87]:
def _batch_to_device(batch, device):
    """Return the tensors of one collated batch moved onto ``device``."""
    return tuple(tensor.to(device) for tensor in batch)


def train_model(model, train_loader, val_loader, num_epochs, device):
    """Train ``model`` and report validation metrics after every epoch.

    The objective is the sum of an MSE loss on the predicted value and a
    cross-entropy loss on the predicted unit, optimised with Adam (lr=0.001).

    Args:
        model: Module called as ``model(images, entity_names, group_ids)`` and
            returning ``(values, unit_logits)``; values are expected to have a
            trailing dimension of size 1.
        train_loader: Iterable of 5-tuples
            ``(images, entity_names, group_ids, true_values, true_units)``.
        val_loader: Iterable with the same batch layout, used for evaluation.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.

    Returns:
        The same ``model`` instance, after training.
    """
    model.to(device)

    # Define loss functions
    value_criterion = nn.MSELoss()
    unit_criterion = nn.CrossEntropyLoss()

    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, entity_names, group_ids, true_values, true_units = _batch_to_device(batch, device)

            # Forward pass
            pred_values, pred_units = model(images, entity_names, group_ids)

            # Compute losses. squeeze(-1) drops only the trailing size-1
            # dimension; a bare squeeze() would also collapse a batch of one
            # sample to a 0-d tensor that no longer matches the targets.
            value_loss = value_criterion(pred_values.squeeze(-1), true_values)
            unit_loss = unit_criterion(pred_units, true_units)

            # Combine losses
            loss = value_loss + unit_loss

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError when the loader is empty
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0.0
        correct_units = 0
        total_samples = 0

        with torch.no_grad():
            for batch in val_loader:
                images, entity_names, group_ids, true_values, true_units = _batch_to_device(batch, device)

                pred_values, pred_units = model(images, entity_names, group_ids)

                value_loss = value_criterion(pred_values.squeeze(-1), true_values)
                unit_loss = unit_criterion(pred_units, true_units)
                loss = value_loss + unit_loss

                val_loss += loss.item()

                # The predicted unit is the class with the highest logit
                _, predicted_units = torch.max(pred_units, 1)
                correct_units += (predicted_units == true_units).sum().item()
                total_samples += true_units.size(0)

        avg_val_loss = val_loss / max(len(val_loader), 1)
        unit_accuracy = correct_units / total_samples if total_samples else 0.0
        print(f"Validation Loss: {avg_val_loss:.4f}, Unit Accuracy: {unit_accuracy:.4f}")

    return model

In [88]:
# # Define loss function and optimizer
# criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.PAD_token)
# optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [89]:
# Number of full passes over the training data
num_epochs = 1

In [None]:
# Run training with per-epoch validation; train_model returns the model it was given
trained_model = train_model(model, train_loader, val_loader, num_epochs, device)

Epoch 1/1:   0%|          | 2/6508 [03:24<182:47:56, 101.15s/it]