<a href="https://colab.research.google.com/github/Dhyanesh-Panchal/amazon_ml_2024_deepinsight/blob/master/solution/Multi_input_endoder_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

graph TD
    A[Input Image] --> B[Vision Encoder]
    C[Text Query] --> D[Text Encoder]
    B --> E[Cross-Attention Transformer]
    D --> E
    E --> F[Decoder]
    F --> G[Output Text]

In [1]:
import torch
import torch.nn as nn
from transformers import ViTModel

import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence

from PIL import Image
import requests
from io import BytesIO

import pandas as pd
import numpy as np
import re
from typing import List, Tuple
from tqdm import tqdm

from sklearn.model_selection import train_test_split

## Model & Supporter Classes

In [2]:
class OutputTokenizer:
    def __init__(self, unit_list: List[str]):
        self.unit_list = unit_list
        self.unit_to_id = {unit: i for i, unit in enumerate(unit_list)}
        self.id_to_unit = {i: unit for i, unit in enumerate(unit_list)}
        self.num_units = len(unit_list)

        # Special tokens
        self.PAD_token = 0
        self.SOS_token = 1
        self.EOS_token = 2
        self.UNK_token = 3 # for Unknown tokens

        # Vocabulary size: special tokens + digits + decimal point + units
        self.vocab_size = 4 + 10 + 1 + self.num_units

    def tokenize(self, text: str, max_length: int = 50) -> List[int]:
        # Split the input into number and unit
        match = re.match(r'(\d+\.?\d*)\s*(\w+)', text.strip())
        if not match:
            return [self.SOS_token, self.UNK_token, self.EOS_token] + [self.PAD_token] * (max_length - 3)

        number, unit = match.groups()

        # Tokenize the number
        number_tokens = [int(digit) + 4 for digit in number if digit.isdigit()]
        if '.' in number:
            number_tokens.insert(number.index('.'), 14)  # 14 is the token for decimal point

        # Tokenize the unit
        unit_token = self.unit_to_id.get(unit, self.UNK_token)

        # Combine tokens
        tokens = [self.SOS_token] + number_tokens + [unit_token + 15] + [self.EOS_token]

        # Pad or truncate to max_length
        if len(tokens) < max_length:
            tokens += [self.PAD_token] * (max_length - len(tokens))
        else:
            tokens = tokens[:max_length-1] + [self.EOS_token]

        return tokens

    def detokenize(self, tokens: List[int]) -> str:
        number = ''
        unit = ''
        for token in tokens:
            if token == self.SOS_token or token == self.PAD_token:
                continue
            elif token == self.EOS_token:
                break
            elif 4 <= token <= 13:
                number += str(token - 4)
            elif token == 14:
                number += '.'
            elif token >= 15:
                unit = self.id_to_unit.get(token - 15, 'UNK')

        return f"{number} {unit}"

    def decode(self, token_ids: List[int]) -> str:
        return self.detokenize(token_ids)


In [3]:
class EntityPredictorVisionBasedModel(nn.Module):
    def __init__(self, num_entity_names, num_group_ids, vocab_size, max_length=50):
        super().__init__()
        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

        self.entity_embedding = nn.Embedding(num_entity_names, 768)
        self.group_embedding = nn.Embedding(num_group_ids, 768)

        self.cross_attention = nn.MultiheadAttention(embed_dim=768, num_heads=8)

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=768, nhead=8),
            num_layers=6
        )

        self.output_layer = nn.Linear(768, vocab_size)
        self.embedding = nn.Embedding(vocab_size, 768)

        self.max_length = max_length

    def forward(self, images, entity_names, group_ids, target_texts=None):
        # Encode images
        image_features = self.image_encoder(images).last_hidden_state

        # Encode entity names and group IDs
        entity_embeddings = self.entity_embedding(entity_names)
        group_embeddings = self.group_embedding(group_ids)

        # Combine entity and group embeddings
        query = entity_embeddings + group_embeddings
        query = query.unsqueeze(1).repeat(1, self.max_length, 1)

        # Cross-attention
        image_features = image_features.transpose(0, 1)
        query = query.transpose(0, 1)
        attended_features, _ = self.cross_attention(query, image_features, image_features)

        batch_size = images.size(0)
        seq_length = self.max_length

        if self.training and target_texts is not None:
            # During training, use teacher forcing
            decoder_inputs = self.embedding(target_texts[:, :-1])
            decoder_outputs = self.decode(decoder_inputs, attended_features)
        else:
            # During validation/inference, generate the full sequence
            decoder_input = torch.zeros((batch_size, 1, 768), device=images.device)
            decoder_outputs = []

            for _ in range(seq_length):
                step_output = self.decode(decoder_input, attended_features)
                decoder_outputs.append(step_output)
                next_token = step_output.argmax(dim=-1)
                decoder_input = self.embedding(next_token)

            decoder_outputs = torch.cat(decoder_outputs, dim=1)

        return decoder_outputs

    def decode(self, decoder_input, attended_features):
        # Transpose for decoder
        decoder_input = decoder_input.transpose(0, 1)

        # Decode
        decoder_output = self.decoder(decoder_input, attended_features)

        # Generate output probabilities
        output_probs = self.output_layer(decoder_output.transpose(0, 1))

        return output_probs

    def generate(self, images, entity_names, group_ids):
        self.eval()
        with torch.no_grad():
            output_probs = self(images, entity_names, group_ids)
            generated_tokens = output_probs.argmax(dim=-1)
        return generated_tokens

In [4]:
# Helper functions
def entity_to_index(entity_names):
    unique_entities = sorted(set(entity_names))
    entity_to_idx = {entity: idx for idx, entity in enumerate(unique_entities)}
    return [entity_to_idx[entity] for entity in entity_names]

def group_to_index(group_ids):
    unique_groups = sorted(set(group_ids))
    group_to_idx = {group: idx for idx, group in enumerate(unique_groups)}
    return [group_to_idx[group] for group in group_ids]


In [5]:
class ProductImageDataset(Dataset):
    def __init__(self, df, tokenizer, transform=None, max_length=50):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]

            # Load image from URL
            response = requests.get(row['image_link'], timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert('RGB')
            img = self.transform(img)

            # Prepare other inputs
            entity_name = torch.tensor(row['entity_name_index'], dtype=torch.long)
            group_id = torch.tensor(row['group_id_index'], dtype=torch.long)

            # Prepare target
            target = row['entity_value']
            target = self.tokenizer.tokenize(target, max_length=self.max_length)
            target = torch.tensor(target, dtype=torch.long)

            return img, entity_name, group_id, target

        except Exception as e:
            print(f"Error loading data at index {idx}: {str(e)}")
            # Return a default value with correct shapes
            return torch.zeros((3, 224, 224)), torch.tensor(0, dtype=torch.long), torch.tensor(0, dtype=torch.long), torch.tensor([0] * self.max_length, dtype=torch.long)

## Training

In [6]:
# Set random seed for reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x7edde43a09f0>

In [7]:
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [8]:
df = pd.read_csv('https://raw.githubusercontent.com/Dhyanesh-Panchal/amazon_ml_2024_deepinsight/refs/heads/master/student_resource%203/dataset/filtered_train.csv')
df.head(2)

Unnamed: 0,index,image_link,group_id,entity_name,entity_value
0,0,https://m.media-amazon.com/images/I/61I9XdN6OF...,748919,item_weight,500.0 gram
1,1,https://m.media-amazon.com/images/I/71gSRbyXmo...,916768,item_volume,1.0 cup


In [9]:
df = df.dropna(subset=['entity_value'])

### Lets sample a small subset

In [10]:
train_batch_no = 1
train_batch_size = 25000

In [11]:
df = df[(train_batch_no-1)*train_batch_size : (train_batch_no)*train_batch_size]

In [12]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25000 entries, 0 to 24999
Data columns (total 5 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   index         25000 non-null  int64 
 1   image_link    25000 non-null  object
 2   group_id      25000 non-null  int64 
 3   entity_name   25000 non-null  object
 4   entity_value  25000 non-null  object
dtypes: int64(2), object(3)
memory usage: 976.7+ KB


In [13]:
# Prepare entity names and group IDs
entity_to_index = {entity: idx for idx, entity in enumerate(df['entity_name'].unique())}
group_to_index = {group: idx for idx, group in enumerate(df['group_id'].unique())}

In [14]:
df['entity_name_index'] = df['entity_name'].map(entity_to_index)
df['group_id_index'] = df['group_id'].map(group_to_index)

In [15]:
# Prepare unit list
df['entity_value_unit'] = df['entity_value'].apply(lambda x: " ".join(x.split(" ")[1:]))
unit_list = df['entity_value_unit'].unique().tolist()

In [16]:
unit_list

['gram',
 'cup',
 'milligram',
 'kilogram',
 'ounce',
 'gallon',
 'volt',
 'watt',
 'pound',
 'millilitre',
 'cubic foot',
 'fluid ounce',
 'ton',
 'decilitre',
 'cubic inch',
 'litre',
 'microgram',
 'centimetre',
 'quart',
 'horsepower',
 'kilowatt',
 'kilowatt hour',
 'gigabyte',
 'millimetre',
 'pint',
 'centilitre',
 'candela',
 'inch',
 'person',
 'metre',
 'foot']

In [17]:
# Split the data
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

In [18]:
tokenizer = OutputTokenizer(unit_list)

In [19]:
train_dataset = ProductImageDataset(train_df, tokenizer, max_length=50)
val_dataset = ProductImageDataset(val_df, tokenizer, max_length=50)

In [20]:
train_dataset[3][0].shape

torch.Size([3, 224, 224])

In [21]:
def custom_collate(batch):
    # Separate the batch into individual components
    images, entity_names, group_ids, targets = zip(*batch)

    # Stack images, entity_names, and group_ids (assuming they're already tensors of uniform size)
    images = torch.stack(images, 0)
    entity_names = torch.stack(entity_names, 0)
    group_ids = torch.stack(group_ids, 0)

    # Pad the target sequences
    targets = pad_sequence(targets, batch_first=True, padding_value=0)

    return images, entity_names, group_ids, targets

In [22]:
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4,collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4,collate_fn=custom_collate)




In [23]:
# Initialise the Model
model = EntityPredictorVisionBasedModel(
    num_entity_names=len(entity_to_index),
    num_group_ids=len(group_to_index),
    vocab_size=tokenizer.vocab_size,
    max_length=50
).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [24]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch_idx, (img, entity_name, group_id, target) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            img = img.to(device)
            entity_name = entity_name.to(device)
            group_id = group_id.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output_probs = model(img, entity_name, group_id, target)

            # Reshape output_probs and target for loss calculation
            output_probs = output_probs.view(-1, output_probs.size(-1))
            target = target[:, 1:].contiguous().view(-1)  # Shift target by 1 and flatten

            # Create a mask to ignore padding in loss calculation
            mask = (target != 0).float()

            loss = criterion(output_probs, target)
            loss = (loss * mask).sum() / mask.sum()  # Average loss over non-pad tokens

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")
                print(f"Output probs shape: {output_probs.shape}")
                print(f"Target shape: {target.shape}")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        correct_predictions = 0
        total_predictions = 0

        with torch.no_grad():
            for img, entity_name, group_id, target in tqdm(val_loader, desc="Validation"):
                img = img.to(device)
                entity_name = entity_name.to(device)
                group_id = group_id.to(device)
                target = target.to(device)

                print(f"Validation batch shapes:")
                print(f"img: {img.shape}")
                print(f"entity_name: {entity_name.shape}")
                print(f"group_id: {group_id.shape}")
                print(f"target: {target.shape}")

                output_probs = model(img, entity_name, group_id)
                print(f"output_probs: {output_probs.shape}")

                # Reshape output_probs to match target size
                output_probs = output_probs.squeeze(1)  # Remove the extra dimension
                output_probs = output_probs.view(-1, output_probs.size(-1))
                target = target.view(-1)

                print(f"Reshaped output_probs: {output_probs.shape}")
                print(f"Reshaped target: {target.shape}")

                # Create a mask to ignore padding in loss calculation
                mask = (target != 0).float()

                loss = criterion(output_probs, target)
                loss = (loss * mask).sum() / mask.sum()  # Average loss over non-pad tokens
                val_loss += loss.item()

                # Calculate accuracy
                predictions = output_probs.argmax(dim=-1)
                correct_predictions += ((predictions == target) * mask).sum().item()
                total_predictions += mask.sum().item()

        avg_val_loss = val_loss / len(val_loader)
        accuracy = correct_predictions / total_predictions
        print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}")

    return model

In [25]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.PAD_token)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [33]:
num_epochs = 1

In [35]:
model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)

Epoch 1/1:   0%|          | 1/625 [00:07<1:23:10,  8.00s/it]

Batch 0, Loss: 0.3094
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:   2%|▏         | 11/625 [00:22<15:01,  1.47s/it]

Batch 10, Loss: 0.4889
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:   3%|▎         | 21/625 [00:40<20:23,  2.03s/it]

Batch 20, Loss: 0.4547
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:   5%|▍         | 31/625 [00:54<13:40,  1.38s/it]

Batch 30, Loss: 0.4979
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:   7%|▋         | 41/625 [01:11<14:41,  1.51s/it]

Batch 40, Loss: 0.4866
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:   8%|▊         | 49/625 [01:24<15:45,  1.64s/it]

Error loading data at index 12976: image file is truncated (2 bytes not processed)


Epoch 1/1:   8%|▊         | 51/625 [01:26<14:04,  1.47s/it]

Batch 50, Loss: 0.4699
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  10%|▉         | 61/625 [01:43<14:52,  1.58s/it]

Batch 60, Loss: 0.3827
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  11%|█▏        | 71/625 [01:59<19:16,  2.09s/it]

Batch 70, Loss: 0.3533
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  13%|█▎        | 81/625 [02:14<13:09,  1.45s/it]

Batch 80, Loss: 0.5994
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  15%|█▍        | 91/625 [02:33<13:59,  1.57s/it]

Batch 90, Loss: 0.3259
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  16%|█▌        | 101/625 [02:50<13:24,  1.53s/it]

Batch 100, Loss: 0.3981
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  18%|█▊        | 111/625 [03:08<17:38,  2.06s/it]

Batch 110, Loss: 0.2854
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  19%|█▉        | 121/625 [03:21<11:28,  1.37s/it]

Batch 120, Loss: 0.4503
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  21%|██        | 131/625 [03:40<13:00,  1.58s/it]

Batch 130, Loss: 0.4058
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  23%|██▎       | 141/625 [03:56<12:34,  1.56s/it]

Batch 140, Loss: 0.5302
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  24%|██▍       | 151/625 [04:11<14:08,  1.79s/it]

Batch 150, Loss: 0.6492
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  26%|██▌       | 160/625 [04:24<10:45,  1.39s/it]

Error loading data at index 4498: 400 Client Error: Bad Request for url: https://m.media-amazon.com/images/I/1yw53vfQtS.jpg


Epoch 1/1:  26%|██▌       | 161/625 [04:25<10:36,  1.37s/it]

Batch 160, Loss: 0.4478
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  27%|██▋       | 171/625 [04:44<12:35,  1.66s/it]

Batch 170, Loss: 0.5048
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  29%|██▉       | 181/625 [05:01<12:26,  1.68s/it]

Batch 180, Loss: 0.4147
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  31%|███       | 191/625 [05:16<14:12,  1.97s/it]

Batch 190, Loss: 0.4275
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  32%|███▏      | 201/625 [05:32<10:03,  1.42s/it]

Batch 200, Loss: 0.4193
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  34%|███▍      | 211/625 [05:49<12:50,  1.86s/it]

Batch 210, Loss: 0.3050
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  35%|███▌      | 221/625 [06:04<10:47,  1.60s/it]

Batch 220, Loss: 0.4817
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  37%|███▋      | 231/625 [06:21<10:33,  1.61s/it]

Batch 230, Loss: 0.4991
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  39%|███▊      | 241/625 [06:38<09:34,  1.50s/it]

Batch 240, Loss: 0.4649
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  40%|████      | 251/625 [06:56<12:47,  2.05s/it]

Batch 250, Loss: 0.3987
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  40%|████      | 253/625 [06:59<10:27,  1.69s/it]

Error loading data at index 11406: 404 Client Error: Not Found for url: https://m.media-amazon.com/images/I/lwd2cSmT2ux.jpg


Epoch 1/1:  42%|████▏     | 261/625 [07:13<09:41,  1.60s/it]

Batch 260, Loss: 0.4035
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  43%|████▎     | 271/625 [07:29<08:20,  1.41s/it]

Batch 270, Loss: 0.4025
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  45%|████▍     | 281/625 [07:46<09:11,  1.60s/it]

Batch 280, Loss: 0.3713
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  47%|████▋     | 291/625 [08:02<10:11,  1.83s/it]

Batch 290, Loss: 0.4658
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  48%|████▊     | 301/625 [08:17<07:28,  1.38s/it]

Batch 300, Loss: 0.4555
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  50%|████▉     | 311/625 [08:34<08:34,  1.64s/it]

Batch 310, Loss: 0.3623
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  51%|█████▏    | 321/625 [08:50<07:56,  1.57s/it]

Batch 320, Loss: 0.4319
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  53%|█████▎    | 331/625 [09:07<09:37,  1.97s/it]

Batch 330, Loss: 0.4243
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  55%|█████▍    | 341/625 [09:21<06:29,  1.37s/it]

Batch 340, Loss: 0.3742
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  56%|█████▌    | 351/625 [09:40<08:44,  1.91s/it]

Batch 350, Loss: 0.4082
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  57%|█████▋    | 354/625 [09:47<11:02,  2.44s/it]

Error loading data at index 16690: 400 Client Error: Bad Request for url: https://m.media-amazon.com/images/I/DzP2RMRQO0.jpg


Epoch 1/1:  58%|█████▊    | 361/625 [09:59<07:49,  1.78s/it]

Batch 360, Loss: 0.3930
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  59%|█████▉    | 371/625 [10:14<06:23,  1.51s/it]

Batch 370, Loss: 0.3597
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  61%|██████    | 381/625 [10:27<05:29,  1.35s/it]

Batch 380, Loss: 0.3825
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  63%|██████▎   | 391/625 [10:44<05:31,  1.41s/it]

Batch 390, Loss: 0.8249
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  64%|██████▍   | 401/625 [11:00<05:47,  1.55s/it]

Batch 400, Loss: 0.2486
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  66%|██████▌   | 411/625 [11:17<06:38,  1.86s/it]

Batch 410, Loss: 0.3463
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  67%|██████▋   | 421/625 [11:31<04:38,  1.37s/it]

Batch 420, Loss: 0.5105
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  69%|██████▉   | 431/625 [11:47<04:28,  1.39s/it]

Batch 430, Loss: 0.3981
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  71%|███████   | 441/625 [12:03<04:28,  1.46s/it]

Batch 440, Loss: 0.4844
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  72%|███████▏  | 451/625 [12:20<05:08,  1.77s/it]

Batch 450, Loss: 0.6036
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  74%|███████▍  | 461/625 [12:37<04:40,  1.71s/it]

Batch 460, Loss: 0.3579
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  75%|███████▌  | 471/625 [12:54<04:37,  1.80s/it]

Batch 470, Loss: 0.2642
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  77%|███████▋  | 481/625 [13:08<03:17,  1.37s/it]

Batch 480, Loss: 0.4186
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  79%|███████▊  | 491/625 [13:24<03:28,  1.56s/it]

Batch 490, Loss: 0.5722
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  80%|████████  | 501/625 [13:39<02:51,  1.38s/it]

Batch 500, Loss: 0.4940
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  82%|████████▏ | 511/625 [13:55<03:05,  1.63s/it]

Batch 510, Loss: 0.3906
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  83%|████████▎ | 521/625 [14:11<02:45,  1.59s/it]

Batch 520, Loss: 0.3896
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  85%|████████▍ | 531/625 [14:26<02:08,  1.36s/it]

Batch 530, Loss: 0.3744
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  87%|████████▋ | 541/625 [14:40<01:54,  1.36s/it]

Batch 540, Loss: 0.3288
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  88%|████████▊ | 551/625 [14:56<01:44,  1.42s/it]

Batch 550, Loss: 0.3367
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  90%|████████▉ | 561/625 [15:10<01:32,  1.44s/it]

Batch 560, Loss: 0.4057
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  91%|█████████▏| 571/625 [15:28<01:58,  2.19s/it]

Batch 570, Loss: 0.5691
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  93%|█████████▎| 581/625 [15:44<01:07,  1.53s/it]

Batch 580, Loss: 0.4580
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  95%|█████████▍| 591/625 [16:01<00:56,  1.65s/it]

Batch 590, Loss: 0.4585
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  96%|█████████▌| 601/625 [16:15<00:33,  1.38s/it]

Batch 600, Loss: 0.4244
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  98%|█████████▊| 611/625 [16:31<00:19,  1.41s/it]

Batch 610, Loss: 0.4162
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1:  99%|█████████▉| 621/625 [16:48<00:06,  1.65s/it]

Batch 620, Loss: 0.3912
Output probs shape: torch.Size([1568, 46])
Target shape: torch.Size([1568])


Epoch 1/1: 100%|██████████| 625/625 [16:53<00:00,  1.62s/it]


Epoch 1/1, Average Loss: 0.4290


Validation:   0%|          | 0/157 [00:00<?, ?it/s]

Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   1%|          | 1/157 [00:09<23:50,  9.17s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   1%|▏         | 2/157 [00:10<12:30,  4.84s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   2%|▏         | 3/157 [00:12<08:54,  3.47s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   3%|▎         | 4/157 [00:14<06:39,  2.61s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   3%|▎         | 5/157 [00:15<05:16,  2.08s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   4%|▍         | 6/157 [00:16<04:34,  1.82s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   4%|▍         | 7/157 [00:17<04:08,  1.66s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   5%|▌         | 8/157 [00:19<03:42,  1.50s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   6%|▌         | 9/157 [00:21<04:04,  1.65s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   6%|▋         | 10/157 [00:22<03:46,  1.54s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   7%|▋         | 11/157 [00:24<04:03,  1.67s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   8%|▊         | 12/157 [00:26<04:23,  1.82s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   8%|▊         | 13/157 [00:28<04:43,  1.97s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:   9%|▉         | 14/157 [00:30<04:45,  1.99s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  10%|▉         | 15/157 [00:32<04:43,  1.99s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  10%|█         | 16/157 [00:34<04:36,  1.96s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  11%|█         | 17/157 [00:35<04:04,  1.74s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  11%|█▏        | 18/157 [00:37<03:38,  1.57s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  12%|█▏        | 19/157 [00:38<03:17,  1.43s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  13%|█▎        | 20/157 [00:39<03:08,  1.38s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  13%|█▎        | 21/157 [00:40<03:00,  1.33s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  14%|█▍        | 22/157 [00:41<02:56,  1.31s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  15%|█▍        | 23/157 [00:43<02:53,  1.29s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  15%|█▌        | 24/157 [00:44<03:00,  1.36s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  16%|█▌        | 25/157 [00:47<04:10,  1.90s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  17%|█▋        | 26/157 [00:50<04:22,  2.01s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  17%|█▋        | 27/157 [00:52<04:24,  2.03s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  18%|█▊        | 28/157 [00:54<04:21,  2.03s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  18%|█▊        | 29/157 [00:56<04:15,  2.00s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  19%|█▉        | 30/157 [00:57<03:47,  1.79s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  20%|█▉        | 31/157 [00:58<03:25,  1.63s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  20%|██        | 32/157 [01:00<03:10,  1.52s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  21%|██        | 33/157 [01:01<03:02,  1.47s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  22%|██▏       | 34/157 [01:02<02:51,  1.39s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  22%|██▏       | 35/157 [01:03<02:50,  1.40s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  23%|██▎       | 36/157 [01:05<02:42,  1.34s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  24%|██▎       | 37/157 [01:06<02:53,  1.44s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  24%|██▍       | 38/157 [01:08<03:13,  1.63s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  25%|██▍       | 39/157 [01:11<03:30,  1.78s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  25%|██▌       | 40/157 [01:12<03:32,  1.82s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  26%|██▌       | 41/157 [01:15<04:06,  2.13s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  27%|██▋       | 42/157 [01:17<03:58,  2.07s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  27%|██▋       | 43/157 [01:19<03:34,  1.88s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  28%|██▊       | 44/157 [01:21<03:41,  1.96s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  29%|██▊       | 45/157 [01:23<03:45,  2.01s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  29%|██▉       | 46/157 [01:25<03:50,  2.07s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  30%|██▉       | 47/157 [01:27<03:54,  2.13s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  31%|███       | 48/157 [01:30<03:54,  2.15s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  31%|███       | 49/157 [01:33<04:45,  2.65s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  32%|███▏      | 50/157 [01:35<04:23,  2.46s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  32%|███▏      | 51/157 [01:37<04:02,  2.29s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  33%|███▎      | 52/157 [01:39<03:51,  2.20s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  34%|███▍      | 53/157 [01:42<03:51,  2.22s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  34%|███▍      | 54/157 [01:43<03:13,  1.88s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  35%|███▌      | 55/157 [01:44<02:51,  1.68s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  36%|███▌      | 56/157 [01:45<02:40,  1.59s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  36%|███▋      | 57/157 [01:47<02:31,  1.51s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  37%|███▋      | 58/157 [01:48<02:23,  1.45s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  38%|███▊      | 59/157 [01:49<02:13,  1.37s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  38%|███▊      | 60/157 [01:50<02:12,  1.36s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  39%|███▉      | 61/157 [01:52<02:07,  1.33s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  39%|███▉      | 62/157 [01:54<02:23,  1.51s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  40%|████      | 63/157 [01:56<02:39,  1.69s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  41%|████      | 64/157 [01:58<02:52,  1.85s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  41%|████▏     | 65/157 [02:00<02:55,  1.91s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  42%|████▏     | 66/157 [02:02<02:51,  1.89s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  43%|████▎     | 67/157 [02:03<02:36,  1.73s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  43%|████▎     | 68/157 [02:05<02:22,  1.60s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  44%|████▍     | 69/157 [02:06<02:17,  1.57s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  45%|████▍     | 70/157 [02:07<02:10,  1.50s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  45%|████▌     | 71/157 [02:09<02:03,  1.44s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  46%|████▌     | 72/157 [02:10<01:58,  1.39s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  46%|████▋     | 73/157 [02:11<01:47,  1.28s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  47%|████▋     | 74/157 [02:12<01:45,  1.28s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  48%|████▊     | 75/157 [02:14<01:59,  1.45s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  48%|████▊     | 76/157 [02:16<02:14,  1.66s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  49%|████▉     | 77/157 [02:18<02:24,  1.81s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  50%|████▉     | 78/157 [02:21<02:31,  1.92s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  50%|█████     | 79/157 [02:23<02:30,  1.93s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  51%|█████     | 80/157 [02:24<02:20,  1.82s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  52%|█████▏    | 81/157 [02:25<02:04,  1.64s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  52%|█████▏    | 82/157 [02:27<01:56,  1.55s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  53%|█████▎    | 83/157 [02:28<01:51,  1.51s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  54%|█████▎    | 84/157 [02:29<01:46,  1.46s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  54%|█████▍    | 85/157 [02:31<01:42,  1.42s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  55%|█████▍    | 86/157 [02:32<01:37,  1.37s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  55%|█████▌    | 87/157 [02:34<01:38,  1.41s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  56%|█████▌    | 88/157 [02:36<01:49,  1.58s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  57%|█████▋    | 89/157 [02:37<01:52,  1.65s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  57%|█████▋    | 90/157 [02:39<01:59,  1.79s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  58%|█████▊    | 91/157 [02:41<02:01,  1.83s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  59%|█████▊    | 92/157 [02:43<02:00,  1.85s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  59%|█████▉    | 93/157 [02:45<01:56,  1.82s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  60%|█████▉    | 94/157 [02:46<01:45,  1.67s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  61%|██████    | 95/157 [02:48<01:36,  1.56s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  61%|██████    | 96/157 [02:49<01:28,  1.45s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  62%|██████▏   | 97/157 [02:50<01:26,  1.43s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  62%|██████▏   | 98/157 [02:52<01:22,  1.39s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  63%|██████▎   | 99/157 [02:53<01:19,  1.37s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  64%|██████▎   | 100/157 [02:54<01:16,  1.34s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  64%|██████▍   | 101/157 [02:56<01:21,  1.45s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  65%|██████▍   | 102/157 [02:58<01:29,  1.63s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  66%|██████▌   | 103/157 [03:00<01:32,  1.71s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  66%|██████▌   | 104/157 [03:02<01:38,  1.85s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  67%|██████▋   | 105/157 [03:04<01:38,  1.90s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  68%|██████▊   | 106/157 [03:06<01:36,  1.90s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  68%|██████▊   | 107/157 [03:08<01:36,  1.92s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  69%|██████▉   | 108/157 [03:09<01:25,  1.75s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  69%|██████▉   | 109/157 [03:11<01:17,  1.62s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  70%|███████   | 110/157 [03:12<01:09,  1.49s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  71%|███████   | 111/157 [03:13<01:04,  1.40s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  71%|███████▏  | 112/157 [03:14<01:02,  1.39s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  72%|███████▏  | 113/157 [03:16<01:00,  1.37s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  73%|███████▎  | 114/157 [03:17<00:58,  1.37s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  73%|███████▎  | 115/157 [03:19<01:01,  1.47s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  74%|███████▍  | 116/157 [03:21<01:06,  1.62s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  75%|███████▍  | 117/157 [03:23<01:10,  1.77s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  75%|███████▌  | 118/157 [03:25<01:15,  1.93s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  76%|███████▌  | 119/157 [03:27<01:14,  1.97s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  76%|███████▋  | 120/157 [03:29<01:10,  1.92s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  77%|███████▋  | 121/157 [03:31<01:07,  1.87s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  78%|███████▊  | 122/157 [03:32<01:04,  1.84s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  78%|███████▊  | 123/157 [03:34<00:59,  1.74s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  79%|███████▉  | 124/157 [03:35<00:53,  1.63s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  80%|███████▉  | 125/157 [03:36<00:47,  1.48s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  80%|████████  | 126/157 [03:38<00:42,  1.38s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  81%|████████  | 127/157 [03:39<00:41,  1.38s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  82%|████████▏ | 128/157 [03:40<00:39,  1.37s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  82%|████████▏ | 129/157 [03:42<00:37,  1.36s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  83%|████████▎ | 130/157 [03:43<00:38,  1.43s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  83%|████████▎ | 131/157 [03:45<00:43,  1.66s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  84%|████████▍ | 132/157 [03:48<00:45,  1.81s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  85%|████████▍ | 133/157 [03:50<00:45,  1.91s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  85%|████████▌ | 134/157 [03:52<00:43,  1.91s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  86%|████████▌ | 135/157 [03:54<00:42,  1.92s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  87%|████████▋ | 136/157 [03:55<00:37,  1.81s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  87%|████████▋ | 137/157 [03:56<00:33,  1.66s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  88%|████████▊ | 138/157 [03:58<00:29,  1.54s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  89%|████████▊ | 139/157 [03:59<00:27,  1.50s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  89%|████████▉ | 140/157 [04:00<00:24,  1.44s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  90%|████████▉ | 141/157 [04:02<00:22,  1.39s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  90%|█████████ | 142/157 [04:03<00:20,  1.34s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  91%|█████████ | 143/157 [04:04<00:18,  1.30s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  92%|█████████▏| 144/157 [04:06<00:19,  1.47s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  92%|█████████▏| 145/157 [04:08<00:20,  1.68s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  93%|█████████▎| 146/157 [04:10<00:20,  1.84s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  94%|█████████▎| 147/157 [04:12<00:19,  1.92s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  94%|█████████▍| 148/157 [04:15<00:18,  2.02s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  95%|█████████▍| 149/157 [04:17<00:16,  2.05s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  96%|█████████▌| 150/157 [04:19<00:15,  2.16s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  96%|█████████▌| 151/157 [04:22<00:13,  2.21s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  97%|█████████▋| 152/157 [04:23<00:10,  2.02s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  97%|█████████▋| 153/157 [04:24<00:06,  1.68s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  98%|█████████▊| 154/157 [04:25<00:04,  1.44s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  99%|█████████▊| 155/157 [04:26<00:02,  1.27s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([32, 3, 224, 224])
entity_name: torch.Size([32])
group_id: torch.Size([32])
target: torch.Size([32, 50])


Validation:  99%|█████████▉| 156/157 [04:27<00:01,  1.15s/it]

output_probs: torch.Size([32, 50, 46])
Reshaped output_probs: torch.Size([1600, 46])
Reshaped target: torch.Size([1600])
Validation batch shapes:
img: torch.Size([8, 3, 224, 224])
entity_name: torch.Size([8])
group_id: torch.Size([8])
target: torch.Size([8, 50])


Validation: 100%|██████████| 157/157 [04:27<00:00,  1.70s/it]

output_probs: torch.Size([8, 50, 46])
Reshaped output_probs: torch.Size([400, 46])
Reshaped target: torch.Size([400])
Validation Loss: 7.8482, Accuracy: 0.0421





## Save Model

In [28]:
save_path = "model_v1_bs25k_no_1.pth"

In [29]:
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

Model saved to model_v1_bs25k_no_1.pth


## Evaluation and dumping the Model

In [30]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import json

def save_model_and_calculate_metrics(model, test_loader, tokenizer, device, save_path='trained_model.pth'):
    # Save the model
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    # Set model to evaluation mode
    model.eval()

    all_predictions = []
    all_targets = []

    # Disable gradient calculations
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            images, entity_names, group_ids, targets = batch
            images = images.to(device)
            entity_names = entity_names.to(device)
            group_ids = group_ids.to(device)

            # Generate predictions
            outputs = model.generate(images, entity_names, group_ids)

            # Convert predictions and targets to text
            pred_texts = [tokenizer.decode(pred) for pred in outputs.cpu().numpy()]
            target_texts = [tokenizer.decode(target) for target in targets.cpu().numpy()]

            all_predictions.extend(pred_texts)
            all_targets.extend(target_texts)

    # Calculate metrics
    accuracy = accuracy_score(all_targets, all_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_predictions, average='weighted')

    # Print metrics
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    # Save metrics to a file
    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }
    with open('metrics.json', 'w') as f:
        json.dump(metrics, f, indent=4)
    print("Metrics saved to metrics.json")

In [31]:
 # Load your test dataset
test_df = pd.read_csv('https://raw.githubusercontent.com/Dhyanesh-Panchal/amazon_ml_2024_deepinsight/refs/heads/master/student_resource%203/dataset/test.csv')
test_df = test_df.sample(frac=0.1, random_state=42)

In [32]:
test_df

Unnamed: 0,index,image_link,group_id,entity_name
81509,81578,https://m.media-amazon.com/images/I/610apmjypy...,311997,height
120375,120463,https://m.media-amazon.com/images/I/71WGncjrqQ...,355666,item_weight
89356,89428,https://m.media-amazon.com/images/I/6158kNWFcS...,407808,item_weight
53381,53428,https://m.media-amazon.com/images/I/51Zj1+VgNT...,267482,item_weight
40962,41007,https://m.media-amazon.com/images/I/51M9bDawya...,245652,width
...,...,...,...,...
23476,23506,https://m.media-amazon.com/images/I/5131GDDywM...,242256,depth
68929,68990,https://m.media-amazon.com/images/I/51qjw2sj1o...,683885,height
94818,94892,https://m.media-amazon.com/images/I/618MHecxNz...,656960,depth
5374,5384,https://m.media-amazon.com/images/I/41CtkPNJYs...,709627,voltage
