<a href="https://colab.research.google.com/github/Dhyanesh-Panchal/amazon_ml_2024_deepinsight/blob/master/solution/Multi_input_endoder_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

graph TD
    A[Input Image] --> B[Vision Encoder]
    C[Text Query] --> D[Text Encoder]
    B --> E[Cross-Attention Transformer]
    D --> E
    E --> F[Decoder]
    F --> G[Output Text]

In [47]:
import torch
import torch.nn as nn
from transformers import ViTModel

import torch.optim as optim

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO

import pandas as pd
import numpy as np
import re
from typing import List, Tuple
from tqdm import tqdm

from sklearn.model_selection import train_test_split

## Model & Supporter Classes

In [14]:
class OutputTokenizer:
    def __init__(self, unit_list: List[str]):
        self.unit_list = unit_list
        self.unit_to_id = {unit: i for i, unit in enumerate(unit_list)}
        self.id_to_unit = {i: unit for i, unit in enumerate(unit_list)}
        self.num_units = len(unit_list)

        # Special tokens
        self.PAD_token = 0
        self.SOS_token = 1
        self.EOS_token = 2
        self.UNK_token = 3

        # Vocabulary size: special tokens + digits + decimal point + units
        self.vocab_size = 4 + 10 + 1 + self.num_units

    def tokenize(self, text: str) -> List[int]:
        # Split the input into number and unit
        match = re.match(r'(\d+\.?\d*)\s*(\w+)', text.strip())
        if not match:
            return [self.UNK_token]

        number, unit = match.groups()

        # Tokenize the number
        number_tokens = [int(digit) + 4 for digit in number if digit.isdigit()]
        if '.' in number:
            number_tokens.insert(number.index('.'), 14)  # 14 is the token for decimal point

        # Tokenize the unit
        unit_token = self.unit_to_id.get(unit, self.UNK_token)

        return [self.SOS_token] + number_tokens + [unit_token + 15] + [self.EOS_token]

    def detokenize(self, tokens: List[int]) -> str:
        number = ''
        unit = ''
        for token in tokens:
            if token == self.SOS_token or token == self.PAD_token:
                continue
            elif token == self.EOS_token:
                break
            elif 4 <= token <= 13:
                number += str(token - 4)
            elif token == 14:
                number += '.'
            elif token >= 15:
                unit = self.id_to_unit.get(token - 15, 'UNK')

        return f"{number} {unit}"


In [41]:
class EntityPredictorVisionBasedModel(nn.Module):
    def __init__(self, num_entity_names, num_group_ids, unit_list: List[str], max_length=50):
        super().__init__()
        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

        self.entity_embedding = nn.Embedding(num_entity_names, 768)
        self.group_embedding = nn.Embedding(num_group_ids, 768)

        self.cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=8)

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=768, nhead=8),
            num_layers=6
        )

        self.tokenizer = OutputTokenizer(unit_list)
        self.output_layer = nn.Linear(768, self.tokenizer.vocab_size)

        self.max_length = max_length

    def forward(self, images, entity_names, group_ids, target_texts=None):
        # Encode images
        image_features = self.image_encoder(images).last_hidden_state  # Shape: (batch_size, num_patches, 768)

        # Encode entity names and group IDs
        entity_embeddings = self.entity_embedding(entity_names)  # Shape: (batch_size, 768)
        group_embeddings = self.group_embedding(group_ids)  # Shape: (batch_size, 768)

        # Combine entity and group embeddings
        query = entity_embeddings + group_embeddings
        query = query.unsqueeze(1)  # Shape: (batch_size, 1, 768)

        # Cross-attention
        attended_features, _ = self.cross_attention(query, image_features, image_features)

        # Prepare decoder input
        if self.training and target_texts is not None:
            decoder_inputs = torch.tensor([self.tokenizer.tokenize(text) for text in target_texts], device=images.device)
            decoder_inputs = decoder_inputs[:, :-1]  # Remove EOS token
        else:
            decoder_inputs = torch.full((images.size(0), 1), self.tokenizer.SOS_token, dtype=torch.long, device=images.device)

        # Decode
        decoder_output = self.decoder(decoder_inputs, attended_features)

        # Generate output probabilities
        output_probs = self.output_layer(decoder_output)

        return output_probs

    def generate(self, images, entity_names, group_ids):
        self.eval()
        with torch.no_grad():
            batch_size = images.size(0)
            decoder_input = torch.full((batch_size, 1), self.tokenizer.SOS_token, dtype=torch.long, device=images.device)

            generated_tokens = []
            for _ in range(self.max_length):
                output_probs = self(images, entity_names, group_ids, decoder_input)
                next_token = output_probs[:, -1, :].argmax(dim=-1)
                generated_tokens.append(next_token.unsqueeze(1))
                decoder_input = torch.cat([decoder_input, next_token.unsqueeze(1)], dim=1)

                if (next_token == self.tokenizer.EOS_token).all():
                    break

            generated_tokens = torch.cat(generated_tokens, dim=1)
            generated_texts = [self.tokenizer.detokenize(tokens) for tokens in generated_tokens.tolist()]

        return generated_texts

In [16]:
# Helper functions
def entity_to_index(entity_names):
    unique_entities = sorted(set(entity_names))
    entity_to_idx = {entity: idx for idx, entity in enumerate(unique_entities)}
    return [entity_to_idx[entity] for entity in entity_names]

def group_to_index(group_ids):
    unique_groups = sorted(set(group_ids))
    group_to_idx = {group: idx for idx, group in enumerate(unique_groups)}
    return [group_to_idx[group] for group in group_ids]


In [51]:
class ProductImageDataset(Dataset):
    def __init__(self, df, tokenizer, transform=None):
        self.df = df
        self.tokenizer = tokenizer
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]

            # Load image from URL
            response = requests.get(row['image_link'], timeout=10)
            response.raise_for_status()  # Raise an exception for bad responses
            img = Image.open(BytesIO(response.content)).convert('RGB')
            img = self.transform(img)

            # Prepare other inputs
            entity_name = torch.tensor(row['entity_name_index'])
            group_id = torch.tensor(row['group_id_index'])

            # Prepare target (if available)
            target = row.get('prediction', None)
            if target is not None:
                target = self.tokenizer.tokenize(target)
                target = torch.tensor(target)
            else:
                print(f"Warning: No target for index {idx}")
                target = torch.tensor([])  # Return an empty tensor instead of None

            return img, entity_name, group_id, target

        except Exception as e:
            print(f"Error loading data at index {idx}: {str(e)}")
            # Return a default value or skip this sample
            return torch.zeros((3, 224, 224)), torch.tensor(0), torch.tensor(0), torch.tensor([])

In [17]:
# Example usage
num_entity_names = len(set(df_x['entity_name']))
num_group_ids = len(set(df_x['group_id']))
unit_list = ['foot', 'kilovolt', 'kilowatt', 'ton', 'volt', 'watt', 'millimetre', 'centimetre', 'pound', 'inch', 'yard', 'metre', 'kilogram', 'gram', 'microgram']
model = UpdatedVisionLanguageModel(num_entity_names, num_group_ids, unit_list)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

## Training

In [23]:
# Set random seed for reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x7fb3b278cc30>

In [24]:
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [25]:
df = pd.read_csv('https://raw.githubusercontent.com/Dhyanesh-Panchal/amazon_ml_2024_deepinsight/refs/heads/master/student_resource%203/dataset/train.csv')
df.head(2)

Unnamed: 0,image_link,group_id,entity_name,entity_value
0,https://m.media-amazon.com/images/I/61I9XdN6OF...,748919,item_weight,500.0 gram
1,https://m.media-amazon.com/images/I/71gSRbyXmo...,916768,item_volume,1.0 cup


In [28]:
# Prepare entity names and group IDs
entity_to_index = {entity: idx for idx, entity in enumerate(df['entity_name'].unique())}
group_to_index = {group: idx for idx, group in enumerate(df['group_id'].unique())}

In [29]:
df['entity_name_index'] = df['entity_name'].map(entity_to_index)
df['group_id_index'] = df['group_id'].map(group_to_index)

In [30]:
# Prepare unit list
unit_list = df['entity_value'].str.split().str[-1].unique().tolist()

In [32]:
# Split the data
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

In [33]:
tokenizer = OutputTokenizer(unit_list)

In [53]:
# Create datasets
train_dataset = ProductImageDataset(train_df, tokenizer)
val_dataset = ProductImageDataset(val_df, tokenizer)

In [54]:
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)


In [55]:
# Initialize the model
model = EntityPredictorVisionBasedModel(
    num_entity_names=len(entity_to_index),
    num_group_ids=len(group_to_index),
    unit_list=unit_list
).to(device)

In [56]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.PAD_token)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [57]:
num_epochs = 3

In [None]:
train_loader

In [58]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for img, entity_name, group_id, target in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        img = img.to(device)
        entity_name = entity_name.to(device)
        group_id = group_id.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output_probs = model(img, entity_name, group_id, target)
        loss = criterion(output_probs.view(-1, output_probs.size(-1)), target.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for img, entity_name, group_id, target in tqdm(val_loader, desc="Validation"):
            img = img.to(device)
            entity_name = entity_name.to(device)
            group_id = group_id.to(device)
            target = target.to(device)

            output_probs = model(img, entity_name, group_id)
            loss = criterion(output_probs.view(-1, output_probs.size(-1)), target.view(-1))
            val_loss += loss.item()

            # Calculate accuracy
            predictions = output_probs.argmax(dim=-1)
            correct_predictions += (predictions == target).all(dim=1).sum().item()
            total_predictions += target.size(0)

    avg_val_loss = val_loss / len(val_loader)
    accuracy = correct_predictions / total_predictions
    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}")

Epoch 1/3:   0%|          | 0/6597 [00:00<?, ?it/s]




































Epoch 1/3:   0%|          | 0/6597 [00:14<?, ?it/s]


RuntimeError: shape '[32, 8, 96]' is invalid for input of size 4841472