graph TD
    A[Input Image] --> B[Vision Encoder]
    C[Text Query] --> D[Text Encoder]
    B --> E[Cross-Attention Transformer]
    D --> E
    E --> F[Decoder]
    F --> G[Output Text]

In [1]:
import torch
import torch.nn as nn
from transformers import ViTModel

import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence

from PIL import Image
import requests
from io import BytesIO

import pandas as pd
import numpy as np
import re
from typing import List, Tuple
from tqdm import tqdm

from sklearn.model_selection import train_test_split

## Model & Supporter Classes

In [2]:
class OutputTokenizer:
    def __init__(self, unit_list: List[str]):
        self.unit_list = unit_list
        self.unit_to_id = {unit: i for i, unit in enumerate(unit_list)}
        self.id_to_unit = {i: unit for i, unit in enumerate(unit_list)}
        self.num_units = len(unit_list)

        # Special tokens
        self.PAD_token = 0
        self.SOS_token = 1
        self.EOS_token = 2
        self.UNK_token = 3 # for Unknown tokens

        # Vocabulary size: special tokens + digits + decimal point + units
        self.vocab_size = 4 + 10 + 1 + self.num_units

    def tokenize(self, text: str, max_length: int = 50) -> List[int]:
        # Split the input into number and unit
        match = re.match(r'(\d+\.?\d*)\s*(\w+)', text.strip())
        if not match:
            return [self.SOS_token, self.UNK_token, self.EOS_token] + [self.PAD_token] * (max_length - 3)

        number, unit = match.groups()

        # Tokenize the number
        number_tokens = [int(digit) + 4 for digit in number if digit.isdigit()]
        if '.' in number:
            number_tokens.insert(number.index('.'), 14)  # 14 is the token for decimal point

        # Tokenize the unit
        unit_token = self.unit_to_id.get(unit, self.UNK_token)

        # Combine tokens
        tokens = [self.SOS_token] + number_tokens + [unit_token + 15] + [self.EOS_token]

        # Pad or truncate to max_length
        if len(tokens) < max_length:
            tokens += [self.PAD_token] * (max_length - len(tokens))
        else:
            tokens = tokens[:max_length-1] + [self.EOS_token]

        return tokens

    def detokenize(self, tokens: List[int]) -> str:
        number = ''
        unit = ''
        for token in tokens:
            if token == self.SOS_token or token == self.PAD_token:
                continue
            elif token == self.EOS_token:
                break
            elif 4 <= token <= 13:
                number += str(token - 4)
            elif token == 14:
                number += '.'
            elif token >= 15:
                unit = self.id_to_unit.get(token - 15, 'UNK')

        return f"{number} {unit}"

    def decode(self, token_ids: List[int]) -> str:
        return self.detokenize(token_ids)


In [29]:
class EntityPredictorVisionBasedModel(nn.Module):
    def __init__(self, num_entity_names, num_group_ids, vocab_size, max_length=50):
        super().__init__()
        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

        self.entity_embedding = nn.Embedding(num_entity_names, 768)
        self.group_embedding = nn.Embedding(num_group_ids, 768)

        self.cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=8)

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=768, nhead=8),
            num_layers=6
        )

        self.output_layer = nn.Linear(768, vocab_size)
        self.embedding = nn.Embedding(vocab_size, 768)

        self.max_length = max_length

    def forward(self, images, entity_names, group_ids, target_texts=None):
        # Encode images
        image_features = self.image_encoder(images).last_hidden_state

        # Encode entity names and group IDs
        entity_embeddings = self.entity_embedding(entity_names)
        group_embeddings = self.group_embedding(group_ids)

        # Combine entity and group embeddings
        query = entity_embeddings + group_embeddings
        query = query.unsqueeze(1).repeat(1, self.max_length, 1)

        # Cross-attention
        image_features = image_features.transpose(0, 1)
        query = query.transpose(0, 1)
        attended_features, _ = self.cross_attention(query, image_features, image_features)

        batch_size = images.size(0)
        seq_length = self.max_length

        if self.training and target_texts is not None:
            # During training, use teacher forcing
            decoder_inputs = self.embedding(target_texts[:, :-1])
            decoder_outputs = self.decode(decoder_inputs, attended_features)
        else:
            # During validation/inference, generate the full sequence
            decoder_input = torch.zeros((batch_size, 1, 768), device=images.device)
            decoder_outputs = []

            for _ in range(seq_length):
                step_output = self.decode(decoder_input, attended_features)
                decoder_outputs.append(step_output)
                next_token = step_output.argmax(dim=-1)
                decoder_input = self.embedding(next_token)

            decoder_outputs = torch.cat(decoder_outputs, dim=1)

        return decoder_outputs

    def decode(self, decoder_input, attended_features):
        # Transpose for decoder
        decoder_input = decoder_input.transpose(0, 1)

        # Decode
        decoder_output = self.decoder(decoder_input, attended_features)

        # Generate output probabilities
        output_probs = self.output_layer(decoder_output.transpose(0, 1))

        return output_probs

    def generate(self, images, entity_names, group_ids):
        self.eval()
        with torch.no_grad():
            output_probs = self(images, entity_names, group_ids)
            generated_tokens = output_probs.argmax(dim=-1)
        return generated_tokens

In [4]:
# Helper functions
def entity_to_index(entity_names):
    unique_entities = sorted(set(entity_names))
    entity_to_idx = {entity: idx for idx, entity in enumerate(unique_entities)}
    return [entity_to_idx[entity] for entity in entity_names]

def group_to_index(group_ids):
    unique_groups = sorted(set(group_ids))
    group_to_idx = {group: idx for idx, group in enumerate(unique_groups)}
    return [group_to_idx[group] for group in group_ids]


In [5]:
class ProductImageDataset(Dataset):
    def __init__(self, df, tokenizer, transform=None, max_length=50):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]

            # Load image from URL
            response = requests.get(row['image_link'], timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert('RGB')
            img = self.transform(img)

            # Prepare other inputs
            entity_name = torch.tensor(row['entity_name_index'], dtype=torch.long)
            group_id = torch.tensor(row['group_id_index'], dtype=torch.long)

            # Prepare target
            target = row['entity_value']
            target = self.tokenizer.tokenize(target, max_length=self.max_length)
            target = torch.tensor(target, dtype=torch.long)

            return img, entity_name, group_id, target

        except Exception as e:
            print(f"Error loading data at index {idx}: {str(e)}")
            # Return a default value with correct shapes
            return torch.zeros((3, 224, 224)), torch.tensor(0, dtype=torch.long), torch.tensor(0, dtype=torch.long), torch.tensor([0] * self.max_length, dtype=torch.long)

## Training

In [6]:
# Set random seed for reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x7c956c3809f0>

In [7]:
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [8]:
df = pd.read_csv('https://raw.githubusercontent.com/Dhyanesh-Panchal/amazon_ml_2024_deepinsight/refs/heads/master/student_resource%203/dataset/filtered_train.csv')
df.head(2)

Unnamed: 0,index,image_link,group_id,entity_name,entity_value
0,0,https://m.media-amazon.com/images/I/61I9XdN6OF...,748919,item_weight,500.0 gram
1,1,https://m.media-amazon.com/images/I/71gSRbyXmo...,916768,item_volume,1.0 cup


In [9]:
df = df.dropna(subset=['entity_value'])

### Lets sample a small subset

In [10]:
train_batch_no = 1
train_batch_size = 100

In [11]:
df = df[(train_batch_no-1)*train_batch_size : (train_batch_no)*train_batch_size]

In [12]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100 entries, 0 to 99
Data columns (total 5 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   index         100 non-null    int64 
 1   image_link    100 non-null    object
 2   group_id      100 non-null    int64 
 3   entity_name   100 non-null    object
 4   entity_value  100 non-null    object
dtypes: int64(2), object(3)
memory usage: 4.0+ KB


In [13]:
# Prepare entity names and group IDs
entity_to_index = {entity: idx for idx, entity in enumerate(df['entity_name'].unique())}
group_to_index = {group: idx for idx, group in enumerate(df['group_id'].unique())}

In [14]:
df['entity_name_index'] = df['entity_name'].map(entity_to_index)
df['group_id_index'] = df['group_id'].map(group_to_index)

In [15]:
# Prepare unit list
df['entity_value_unit'] = df['entity_value'].apply(lambda x: " ".join(x.split(" ")[1:]))
unit_list = df['entity_value_unit'].unique().tolist()

In [16]:
unit_list

['gram',
 'cup',
 'milligram',
 'kilogram',
 'ounce',
 'gallon',
 'volt',
 'watt',
 'pound',
 'millilitre',
 'cubic foot']

In [17]:
# Split the data
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

In [18]:
tokenizer = OutputTokenizer(unit_list)

In [19]:
train_dataset = ProductImageDataset(train_df, tokenizer, max_length=50)
val_dataset = ProductImageDataset(val_df, tokenizer, max_length=50)

In [20]:
train_dataset[3][0].shape

torch.Size([3, 224, 224])

In [21]:
def custom_collate(batch):
    # Separate the batch into individual components
    images, entity_names, group_ids, targets = zip(*batch)

    # Stack images, entity_names, and group_ids (assuming they're already tensors of uniform size)
    images = torch.stack(images, 0)
    entity_names = torch.stack(entity_names, 0)
    group_ids = torch.stack(group_ids, 0)

    # Pad the target sequences
    targets = pad_sequence(targets, batch_first=True, padding_value=0)

    return images, entity_names, group_ids, targets

In [22]:
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4,collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4,collate_fn=custom_collate)




In [30]:
# Initialise the Model
model = EntityPredictorVisionBasedModel(
    num_entity_names=len(entity_to_index),
    num_group_ids=len(group_to_index),
    vocab_size=tokenizer.vocab_size,
    max_length=50
).to(device)

In [31]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch_idx, (img, entity_name, group_id, target) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            img = img.to(device)
            entity_name = entity_name.to(device)
            group_id = group_id.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output_probs = model(img, entity_name, group_id, target)

            # Reshape output_probs and target for loss calculation
            output_probs = output_probs.view(-1, output_probs.size(-1))
            target = target[:, 1:].contiguous().view(-1)  # Shift target by 1 and flatten

            # Create a mask to ignore padding in loss calculation
            mask = (target != 0).float()

            loss = criterion(output_probs, target)
            loss = (loss * mask).sum() / mask.sum()  # Average loss over non-pad tokens

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")
                print(f"Output probs shape: {output_probs.shape}")
                print(f"Target shape: {target.shape}")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        correct_predictions = 0
        total_predictions = 0

        with torch.no_grad():
            for img, entity_name, group_id, target in tqdm(val_loader, desc="Validation"):
                img = img.to(device)
                entity_name = entity_name.to(device)
                group_id = group_id.to(device)
                target = target.to(device)

                print(f"Validation batch shapes:")
                print(f"img: {img.shape}")
                print(f"entity_name: {entity_name.shape}")
                print(f"group_id: {group_id.shape}")
                print(f"target: {target.shape}")

                output_probs = model(img, entity_name, group_id)
                print(f"output_probs: {output_probs.shape}")

                # Reshape output_probs to match target size
                output_probs = output_probs.squeeze(1)  # Remove the extra dimension
                output_probs = output_probs.view(-1, output_probs.size(-1))
                target = target.view(-1)

                print(f"Reshaped output_probs: {output_probs.shape}")
                print(f"Reshaped target: {target.shape}")

                # Create a mask to ignore padding in loss calculation
                mask = (target != 0).float()

                loss = criterion(output_probs, target)
                loss = (loss * mask).sum() / mask.sum()  # Average loss over non-pad tokens
                val_loss += loss.item()

                # Calculate accuracy
                predictions = output_probs.argmax(dim=-1)
                correct_predictions += ((predictions == target) * mask).sum().item()
                total_predictions += mask.sum().item()

        avg_val_loss = val_loss / len(val_loader)
        accuracy = correct_predictions / total_predictions
        print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}")

    return model

In [25]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.PAD_token)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [26]:
num_epochs = 3

In [32]:
model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

Epoch 1/3:  33%|███▎      | 1/3 [00:04<00:09,  4.75s/it]

Batch 0, Loss: 3.3417
Output probs shape: torch.Size([1568, 26])
Target shape: torch.Size([1568])


Epoch 1/3: 100%|██████████| 3/3 [00:06<00:00,  2.26s/it]


Epoch 1/3, Average Loss: 3.3396


Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Validation batch shapes:
img: torch.Size([20, 3, 224, 224])
entity_name: torch.Size([20])
group_id: torch.Size([20])
target: torch.Size([20, 50])


Validation: 100%|██████████| 1/1 [00:02<00:00,  2.83s/it]


output_probs: torch.Size([20, 50, 26])
Reshaped output_probs: torch.Size([1000, 26])
Reshaped target: torch.Size([1000])
Validation Loss: 3.5888, Accuracy: 0.0350


Epoch 2/3:  33%|███▎      | 1/3 [00:04<00:09,  4.66s/it]

Batch 0, Loss: 3.3547
Output probs shape: torch.Size([1568, 26])
Target shape: torch.Size([1568])


Epoch 2/3: 100%|██████████| 3/3 [00:06<00:00,  2.25s/it]


Epoch 2/3, Average Loss: 3.3213


Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Validation batch shapes:
img: torch.Size([20, 3, 224, 224])
entity_name: torch.Size([20])
group_id: torch.Size([20])
target: torch.Size([20, 50])


Validation: 100%|██████████| 1/1 [00:03<00:00,  3.31s/it]


output_probs: torch.Size([20, 50, 26])
Reshaped output_probs: torch.Size([1000, 26])
Reshaped target: torch.Size([1000])
Validation Loss: 3.5888, Accuracy: 0.0350


Epoch 3/3:  33%|███▎      | 1/3 [00:06<00:12,  6.23s/it]

Batch 0, Loss: 3.3289
Output probs shape: torch.Size([1568, 26])
Target shape: torch.Size([1568])


Epoch 3/3: 100%|██████████| 3/3 [00:08<00:00,  2.75s/it]


Epoch 3/3, Average Loss: 3.3251


Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Validation batch shapes:
img: torch.Size([20, 3, 224, 224])
entity_name: torch.Size([20])
group_id: torch.Size([20])
target: torch.Size([20, 50])


Validation: 100%|██████████| 1/1 [00:02<00:00,  2.02s/it]

output_probs: torch.Size([20, 50, 26])
Reshaped output_probs: torch.Size([1000, 26])
Reshaped target: torch.Size([1000])
Validation Loss: 3.5888, Accuracy: 0.0350





## Save Model

In [28]:
save_path = "model_v1.pth"

In [29]:
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

Model saved to model_v1.pth


## Evaluation and dumping the Model

In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import json

def save_model_and_calculate_metrics(model, test_loader, tokenizer, device, save_path='trained_model.pth'):
    # Save the model
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    # Set model to evaluation mode
    model.eval()

    all_predictions = []
    all_targets = []

    # Disable gradient calculations
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            images, entity_names, group_ids, targets = batch
            images = images.to(device)
            entity_names = entity_names.to(device)
            group_ids = group_ids.to(device)

            # Generate predictions
            outputs = model.generate(images, entity_names, group_ids)

            # Convert predictions and targets to text
            pred_texts = [tokenizer.decode(pred) for pred in outputs.cpu().numpy()]
            target_texts = [tokenizer.decode(target) for target in targets.cpu().numpy()]

            all_predictions.extend(pred_texts)
            all_targets.extend(target_texts)

    # Calculate metrics
    accuracy = accuracy_score(all_targets, all_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_predictions, average='weighted')

    # Print metrics
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    # Save metrics to a file
    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }
    with open('metrics.json', 'w') as f:
        json.dump(metrics, f, indent=4)
    print("Metrics saved to metrics.json")

In [None]:
 # Load your test dataset
test_df = pd.read_csv('https://raw.githubusercontent.com/Dhyanesh-Panchal/amazon_ml_2024_deepinsight/refs/heads/master/student_resource%203/dataset/test.csv')
test_df = test_df.sample(frac=0.1, random_state=42)

In [None]:
test_df

Unnamed: 0,index,image_link,group_id,entity_name
81509,81578,https://m.media-amazon.com/images/I/610apmjypy...,311997,height
120375,120463,https://m.media-amazon.com/images/I/71WGncjrqQ...,355666,item_weight
89356,89428,https://m.media-amazon.com/images/I/6158kNWFcS...,407808,item_weight
53381,53428,https://m.media-amazon.com/images/I/51Zj1+VgNT...,267482,item_weight
40962,41007,https://m.media-amazon.com/images/I/51M9bDawya...,245652,width
...,...,...,...,...
23476,23506,https://m.media-amazon.com/images/I/5131GDDywM...,242256,depth
68929,68990,https://m.media-amazon.com/images/I/51qjw2sj1o...,683885,height
94818,94892,https://m.media-amazon.com/images/I/618MHecxNz...,656960,depth
5374,5384,https://m.media-amazon.com/images/I/41CtkPNJYs...,709627,voltage
