In [1]:
import pandas as pd
import os
import matplotlib.pyplot as plt
from PIL import Image

In [3]:
train_csv_path = '/kaggle/input/visual-taxonomy/train.csv'
test_csv_path = '/kaggle/input/visual-taxonomy/test.csv'
category_attributes_path = '/kaggle/input/visual-taxonomy/category_attributes.parquet'
sample_submission_path = '/kaggle/input/visual-taxonomy/sample_submission.csv'
images_folder = '/kaggle/input/visual-taxonomy/train_images'

In [4]:
print("\nLoading train.csv ...")
train_df = pd.read_csv(train_csv_path)
print(f"train.csv shape: {train_df.shape}")
print(f"train.csv columns: {train_df.columns}")

train_df.head()


Loading train.csv ...
train.csv shape: (70213, 13)
train.csv columns: Index(['id', 'Category', 'len', 'attr_1', 'attr_2', 'attr_3', 'attr_4',
       'attr_5', 'attr_6', 'attr_7', 'attr_8', 'attr_9', 'attr_10'],
      dtype='object')


Unnamed: 0,id,Category,len,attr_1,attr_2,attr_3,attr_4,attr_5,attr_6,attr_7,attr_8,attr_9,attr_10
0,0,Men Tshirts,5,default,round,printed,default,short sleeves,,,,,
1,1,Men Tshirts,5,multicolor,polo,solid,solid,short sleeves,,,,,
2,2,Men Tshirts,5,default,polo,solid,solid,short sleeves,,,,,
3,3,Men Tshirts,5,multicolor,polo,solid,solid,short sleeves,,,,,
4,4,Men Tshirts,5,multicolor,polo,solid,solid,short sleeves,,,,,


In [4]:
print("\nLoading test.csv ...")
test_df = pd.read_csv(test_csv_path)
print(f"test.csv shape: {test_df.shape}")
print(f"test.csv columns: {test_df.columns}")

test_df.head()



Loading test.csv ...
test.csv shape: (30205, 2)
test.csv columns: Index(['id', 'Category'], dtype='object')


Unnamed: 0,id,Category
0,0,Men Tshirts
1,1,Men Tshirts
2,2,Men Tshirts
3,3,Men Tshirts
4,4,Men Tshirts


In [18]:
category_dfs = {}

for category in train_df['Category'].unique():
    category_dfs[category] = train_df[train_df['Category'] == category]

men_tshirts_df = category_dfs['Men Tshirts']
print(f"Men Tshirts DataFrame:\n{men_tshirts_df.info()}")
print(men_tshirts_df.head())
print(len(men_tshirts_df))

sarees_df = category_dfs['Sarees']
print(f"Sarees DataFrame:\n{sarees_df.info()}")
print(sarees_df.head())
print(len(sarees_df))

kurtis_df = category_dfs['Kurtis']
print(f"Kurtis DataFrame:\n{kurtis_df.info()}")
print(kurtis_df.head())
print(len(kurtis_df))

women_tshirts_df = category_dfs['Women Tshirts']
print(f"Women Tshirts:\n{women_tshirts_df.info()}")
print(women_tshirts_df.head())
print(len(women_tshirts_df))

women_tops_df = category_dfs['Women Tops & Tunics']
print(f"Women Tops and Tunics DataFrame:\n{women_tops_df.info()}")
print(women_tshirts_df.head())
print(len(women_tshirts_df))

<class 'pandas.core.frame.DataFrame'>
Index: 7267 entries, 0 to 7266
Data columns (total 13 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   id        7267 non-null   int64 
 1   Category  7267 non-null   object
 2   len       7267 non-null   int64 
 3   attr_1    6010 non-null   object
 4   attr_2    6144 non-null   object
 5   attr_3    5791 non-null   object
 6   attr_4    5949 non-null   object
 7   attr_5    5977 non-null   object
 8   attr_6    0 non-null      object
 9   attr_7    0 non-null      object
 10  attr_8    0 non-null      object
 11  attr_9    0 non-null      object
 12  attr_10   0 non-null      object
dtypes: int64(2), object(11)
memory usage: 794.8+ KB
Men Tshirts DataFrame:
None
   id     Category  len      attr_1 attr_2   attr_3   attr_4         attr_5  \
0   0  Men Tshirts    5     default  round  printed  default  short sleeves   
1   1  Men Tshirts    5  multicolor   polo    solid    solid  short sleeves   
2   2  Men

In [6]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import ViTModel, ViTImageProcessor
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import torchvision.transforms as T

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [7]:
# Preprocessing function for images
def preprocess_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224])
    return tf.keras.applications.resnet.preprocess_input(image)

In [None]:
# Prepare and process data (assuming sarees_df is pre-loaded)
sarees_df = sarees_df.drop(columns=['Category', 'len'])
sarees_df = sarees_df.apply(
    lambda col: col.fillna('dummy_value') if col.isna().all() 
    else col.fillna(col.mode()[0]) if col.dtype == 'object' or col.dtype.name == 'category' 
    else col
)

# Define attribute columns and category mappings
attribute_columns_sarees = [f'attr_{i}' for i in range(1, 11)]
category_mappings_sarees = {}

for col in attribute_columns_sarees:
    sarees_df[col] = pd.Categorical(sarees_df[col])
    category_mappings_sarees[col] = dict(enumerate(sarees_df[col].cat.categories))
    sarees_df[col] = sarees_df[col].cat.codes

sarees_df['image_path'] = '/kaggle/input/visual-taxonomy/train_images/' + sarees_df['id'].apply(lambda x: f"{x:06d}.jpg")

# Split dataset into train, validation, and test sets
train_df_sarees, temp_df = train_test_split(sarees_df, test_size=0.3, random_state=42)
val_df_sarees, test_df_sarees = train_test_split(temp_df, test_size=0.5, random_state=42)

# Define number of classes for each attribute
num_classes_per_attribute = {
    'attr_1': 4,  
    'attr_2': 6,
    'attr_3': 3,
    'attr_4': 8,
    'attr_5': 4,
    'attr_6': 3,
    'attr_7': 4,
    'attr_8': 5,
    'attr_9': 9,
    'attr_10': 2
}

class MultiOutputModel(nn.Module):
    def __init__(self, base_model, num_classes_per_attribute):
        super(MultiOutputModel, self).__init__()
        self.base_model = base_model

        self.dropout = nn.Dropout(0.3)  
        self.heads = nn.ModuleDict({
            attr: nn.Linear(base_model.config.hidden_size, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })



    def forward(self, pixel_values):
        outputs = self.base_model(pixel_values=pixel_values)
        hidden_states = outputs.last_hidden_state
        cls_hidden_states = hidden_states[:, 0, :]
        cls_hidden_states = self.dropout(cls_hidden_states)  
        logits = {attr: self.heads[attr](cls_hidden_states) for attr in self.heads}
        return logits

base_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model_vit_sarees = MultiOutputModel(base_model, num_classes_per_attribute)


# Data Augmentation

data_augmentation = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


# Dataset and DataLoader

class CustomDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        id_str = str(self.df.iloc[idx]["id"]).zfill(6)
        path = self.df.iloc[idx]["image_path"]
        image = Image.open(path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels = {}
        
        for attr in attribute_columns_sarees:
            labels[attr] = torch.tensor(self.df.iloc[idx][attr], dtype=torch.long)

        return image, labels

# Initialize datasets and data loaders
train_dataset = CustomDataset(train_df_sarees, transform=data_augmentation)
val_dataset = CustomDataset(val_df_sarees, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))
test_dataset = CustomDataset(test_df_sarees, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)  # Increased batch size
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_vit_sarees.to(device)

# Optimizer with Weight Decay

optimizer = torch.optim.AdamW(model_vit_sarees.parameters(), lr=5e-5, weight_decay=0.01)


# Learning Rate Scheduler: Cosine Annealing

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)  # Cosine annealing

scaler = GradScaler()

# Training Loop 

num_epochs = 10  

for epoch in range(num_epochs):
    model_vit_sarees.train()
    running_corrects = {attr: 0 for attr in attribute_columns_sarees}
    total_samples = {attr: 0 for attr in attribute_columns_sarees}

    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        with autocast():
            outputs = model_vit_sarees(pixel_values=images)
            
            losses = {}
            
            for attr in labels.keys():
                loss_fn = nn.CrossEntropyLoss()
                target = labels[attr].long().to(device)
                loss = loss_fn(outputs[attr], target)
                losses[attr] = loss

                # Compute accuracy
                
                _, preds = torch.max(outputs[attr], 1)
                running_corrects[attr] += torch.sum(preds == target.data).item()
                total_samples[attr] += target.size(0)

            total_loss = sum(losses.values())

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    scheduler.step()



    # Validation Accuracy

    model_vit_sarees.eval()
    val_corrects = {attr: 0 for attr in attribute_columns_sarees}
    val_samples = {attr: 0 for attr in attribute_columns_sarees}

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            outputs = model_vit_sarees(pixel_values=images)

            for attr in labels.keys():
                loss_fn = nn.CrossEntropyLoss()
                target = labels[attr].long().to(device)

                loss = loss_fn(outputs[attr], target)
                _, preds = torch.max(outputs[attr], 1)
                val_corrects[attr] += torch.sum(preds == target.data).item()
                val_samples[attr] += target.size(0)

    for attr in attribute_columns_sarees:
        val_acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Epoch {epoch + 1}, Attribute {attr}, Validation Accuracy: {val_acc:.4f}")

# Final test accuracy
model_vit_sarees.eval()
test_corrects = {attr: 0 for attr in attribute_columns_sarees}
test_samples = {attr: 0 for attr in attribute_columns_sarees}

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing", leave=False):
        images = images.to(device)
        outputs = model_vit_sarees(pixel_values=images)
        
        for attr in labels.keys():
            target = labels[attr].long().to(device)
            
            # Compute test accuracy
            _, preds = torch.max(outputs[attr], 1)
            test_corrects[attr] += torch.sum(preds == target.data).item()
            test_samples[attr] += target.size(0)

# Print final test accuracy for each attribute
for attr in attribute_columns_sarees:
    test_acc = test_corrects[attr] / test_samples[attr] if test_samples[attr] > 0 else 0
    print(f"Final Test Accuracy for Attribute {attr}: {test_acc:.4f}")

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = GradScaler()
  self.pid = os.fork()
  with autocast():
  self.pid = os.fork()
                                                                      

Epoch 1, Attribute attr_1, Validation Accuracy: 0.8648
Epoch 1, Attribute attr_2, Validation Accuracy: 0.6770
Epoch 1, Attribute attr_3, Validation Accuracy: 0.8438
Epoch 1, Attribute attr_4, Validation Accuracy: 0.5509
Epoch 1, Attribute attr_5, Validation Accuracy: 0.6962
Epoch 1, Attribute attr_6, Validation Accuracy: 0.9564
Epoch 1, Attribute attr_7, Validation Accuracy: 0.7224
Epoch 1, Attribute attr_8, Validation Accuracy: 0.8630
Epoch 1, Attribute attr_9, Validation Accuracy: 0.6472
Epoch 1, Attribute attr_10, Validation Accuracy: 0.8343


                                                                      

Epoch 2, Attribute attr_1, Validation Accuracy: 0.8594
Epoch 2, Attribute attr_2, Validation Accuracy: 0.6733
Epoch 2, Attribute attr_3, Validation Accuracy: 0.8448
Epoch 2, Attribute attr_4, Validation Accuracy: 0.5523
Epoch 2, Attribute attr_5, Validation Accuracy: 0.7078
Epoch 2, Attribute attr_6, Validation Accuracy: 0.9578
Epoch 2, Attribute attr_7, Validation Accuracy: 0.7260
Epoch 2, Attribute attr_8, Validation Accuracy: 0.8532
Epoch 2, Attribute attr_9, Validation Accuracy: 0.6395
Epoch 2, Attribute attr_10, Validation Accuracy: 0.8339


                                                                      

Epoch 3, Attribute attr_1, Validation Accuracy: 0.8656
Epoch 3, Attribute attr_2, Validation Accuracy: 0.6831
Epoch 3, Attribute attr_3, Validation Accuracy: 0.8470
Epoch 3, Attribute attr_4, Validation Accuracy: 0.5436
Epoch 3, Attribute attr_5, Validation Accuracy: 0.7035
Epoch 3, Attribute attr_6, Validation Accuracy: 0.9549
Epoch 3, Attribute attr_7, Validation Accuracy: 0.7384
Epoch 3, Attribute attr_8, Validation Accuracy: 0.8674
Epoch 3, Attribute attr_9, Validation Accuracy: 0.6541
Epoch 3, Attribute attr_10, Validation Accuracy: 0.8343


                                                                      

Epoch 4, Attribute attr_1, Validation Accuracy: 0.8637
Epoch 4, Attribute attr_2, Validation Accuracy: 0.6897
Epoch 4, Attribute attr_3, Validation Accuracy: 0.8361
Epoch 4, Attribute attr_4, Validation Accuracy: 0.5632
Epoch 4, Attribute attr_5, Validation Accuracy: 0.7071
Epoch 4, Attribute attr_6, Validation Accuracy: 0.9589
Epoch 4, Attribute attr_7, Validation Accuracy: 0.7387
Epoch 4, Attribute attr_8, Validation Accuracy: 0.8619
Epoch 4, Attribute attr_9, Validation Accuracy: 0.6544
Epoch 4, Attribute attr_10, Validation Accuracy: 0.8350


                                                                      

Epoch 5, Attribute attr_1, Validation Accuracy: 0.8641
Epoch 5, Attribute attr_2, Validation Accuracy: 0.6857
Epoch 5, Attribute attr_3, Validation Accuracy: 0.8496
Epoch 5, Attribute attr_4, Validation Accuracy: 0.5552
Epoch 5, Attribute attr_5, Validation Accuracy: 0.7133
Epoch 5, Attribute attr_6, Validation Accuracy: 0.9568
Epoch 5, Attribute attr_7, Validation Accuracy: 0.7402
Epoch 5, Attribute attr_8, Validation Accuracy: 0.8648
Epoch 5, Attribute attr_9, Validation Accuracy: 0.6548
Epoch 5, Attribute attr_10, Validation Accuracy: 0.8358


                                                                      

Epoch 6, Attribute attr_1, Validation Accuracy: 0.8634
Epoch 6, Attribute attr_2, Validation Accuracy: 0.6897
Epoch 6, Attribute attr_3, Validation Accuracy: 0.8467
Epoch 6, Attribute attr_4, Validation Accuracy: 0.5531
Epoch 6, Attribute attr_5, Validation Accuracy: 0.7148
Epoch 6, Attribute attr_6, Validation Accuracy: 0.9575
Epoch 6, Attribute attr_7, Validation Accuracy: 0.7253
Epoch 6, Attribute attr_8, Validation Accuracy: 0.8557
Epoch 6, Attribute attr_9, Validation Accuracy: 0.6537
Epoch 6, Attribute attr_10, Validation Accuracy: 0.8354


                                                                      

Epoch 7, Attribute attr_1, Validation Accuracy: 0.8656
Epoch 7, Attribute attr_2, Validation Accuracy: 0.6926
Epoch 7, Attribute attr_3, Validation Accuracy: 0.8474
Epoch 7, Attribute attr_4, Validation Accuracy: 0.5600
Epoch 7, Attribute attr_5, Validation Accuracy: 0.7184
Epoch 7, Attribute attr_6, Validation Accuracy: 0.9586
Epoch 7, Attribute attr_7, Validation Accuracy: 0.7449
Epoch 7, Attribute attr_8, Validation Accuracy: 0.8601
Epoch 7, Attribute attr_9, Validation Accuracy: 0.6621
Epoch 7, Attribute attr_10, Validation Accuracy: 0.8347


                                                                      

Epoch 8, Attribute attr_1, Validation Accuracy: 0.8659
Epoch 8, Attribute attr_2, Validation Accuracy: 0.6919
Epoch 8, Attribute attr_3, Validation Accuracy: 0.8474
Epoch 8, Attribute attr_4, Validation Accuracy: 0.5636
Epoch 8, Attribute attr_5, Validation Accuracy: 0.7177
Epoch 8, Attribute attr_6, Validation Accuracy: 0.9593
Epoch 8, Attribute attr_7, Validation Accuracy: 0.7435
Epoch 8, Attribute attr_8, Validation Accuracy: 0.8685
Epoch 8, Attribute attr_9, Validation Accuracy: 0.6613
Epoch 8, Attribute attr_10, Validation Accuracy: 0.8354


                                                                      

Epoch 9, Attribute attr_1, Validation Accuracy: 0.8652
Epoch 9, Attribute attr_2, Validation Accuracy: 0.6933
Epoch 9, Attribute attr_3, Validation Accuracy: 0.8470
Epoch 9, Attribute attr_4, Validation Accuracy: 0.5625
Epoch 9, Attribute attr_5, Validation Accuracy: 0.7195
Epoch 9, Attribute attr_6, Validation Accuracy: 0.9593
Epoch 9, Attribute attr_7, Validation Accuracy: 0.7402
Epoch 9, Attribute attr_8, Validation Accuracy: 0.8685
Epoch 9, Attribute attr_9, Validation Accuracy: 0.6592
Epoch 9, Attribute attr_10, Validation Accuracy: 0.8343


                                                                       

Epoch 10, Attribute attr_1, Validation Accuracy: 0.8659
Epoch 10, Attribute attr_2, Validation Accuracy: 0.6900
Epoch 10, Attribute attr_3, Validation Accuracy: 0.8470
Epoch 10, Attribute attr_4, Validation Accuracy: 0.5621
Epoch 10, Attribute attr_5, Validation Accuracy: 0.7191
Epoch 10, Attribute attr_6, Validation Accuracy: 0.9589
Epoch 10, Attribute attr_7, Validation Accuracy: 0.7424
Epoch 10, Attribute attr_8, Validation Accuracy: 0.8652
Epoch 10, Attribute attr_9, Validation Accuracy: 0.6602
Epoch 10, Attribute attr_10, Validation Accuracy: 0.8347


                                                        

Final Test Accuracy for Attribute attr_1: 0.8641
Final Test Accuracy for Attribute attr_2: 0.6959
Final Test Accuracy for Attribute attr_3: 0.8398
Final Test Accuracy for Attribute attr_4: 0.5549
Final Test Accuracy for Attribute attr_5: 0.7009
Final Test Accuracy for Attribute attr_6: 0.9611
Final Test Accuracy for Attribute attr_7: 0.7504
Final Test Accuracy for Attribute attr_8: 0.8634
Final Test Accuracy for Attribute attr_9: 0.6486
Final Test Accuracy for Attribute attr_10: 0.8307




In [None]:
# Prepare and process data 
kurtis_df = kurtis_df.drop(columns=['Category', 'len'])
kurtis_df = kurtis_df.apply(
    lambda col: col.fillna('dummy_value') if col.isna().all() 
    else col.fillna(col.mode()[0]) if col.dtype == 'object' or col.dtype.name == 'category' 
    else col
)

attribute_columns_kurtis = [f'attr_{i}' for i in range(1, 11)]
category_mappings_kurtis = {}

for col in attribute_columns_kurtis:
    kurtis_df[col] = pd.Categorical(kurtis_df[col])
    category_mappings_kurtis[col] = dict(enumerate(kurtis_df[col].cat.categories))
    kurtis_df[col] = kurtis_df[col].cat.codes

kurtis_df['image_path'] = '/kaggle/input/visual-taxonomy/train_images/' + kurtis_df['id'].apply(lambda x: f"{x:06d}.jpg")

train_df_kurtis, temp_df_kurtis = train_test_split(kurtis_df, test_size=0.3, random_state=42)
val_df_kurtis, test_df_kurtis = train_test_split(temp_df_kurtis, test_size=0.5, random_state=42)

num_classes_per_attribute_kurtis = {
    'attr_1': 13,
    'attr_2': 2,
    'attr_3': 2,
    'attr_4': 2,
    'attr_5': 2,
    'attr_6': 2,
    'attr_7': 2,
    'attr_8': 3,
    'attr_9': 2,
    'attr_10': 1
}

class MultiOutputModel(nn.Module):
    def __init__(self, base_model, num_classes_per_attribute):
        super(MultiOutputModel, self).__init__()
        self.base_model = base_model

        self.dropout = nn.Dropout(0.3)  
        self.heads = nn.ModuleDict({
            attr: nn.Linear(base_model.config.hidden_size, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })

    def forward(self, pixel_values):
        outputs = self.base_model(pixel_values=pixel_values)
        hidden_states = outputs.last_hidden_state
        cls_hidden_states = hidden_states[:, 0, :]
        cls_hidden_states = self.dropout(cls_hidden_states) 
        logits = {attr: self.heads[attr](cls_hidden_states) for attr in self.heads}
        return logits

base_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model_vit_kurtis = MultiOutputModel(base_model, num_classes_per_attribute_kurtis)

# Data Augmentation
data_augmentation = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class CustomDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        id_str = str(self.df.iloc[idx]["id"]).zfill(6)
        path = self.df.iloc[idx]["image_path"]
        image = Image.open(path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels = {}
        for attr in attribute_columns_kurtis:
            labels[attr] = torch.tensor(self.df.iloc[idx][attr], dtype=torch.long)

        return image, labels

# Initialize datasets and data loaders
train_dataset_kurtis = CustomDataset(train_df_kurtis, transform=data_augmentation)
val_dataset_kurtis = CustomDataset(val_df_kurtis, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))
test_dataset_kurtis = CustomDataset(test_df_kurtis, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))

train_loader_kurtis = DataLoader(train_dataset_kurtis, batch_size=32, shuffle=True, num_workers=4)
val_loader_kurtis = DataLoader(val_dataset_kurtis, batch_size=32, shuffle=False, num_workers=4)
test_loader_kurtis = DataLoader(test_dataset_kurtis, batch_size=32, shuffle=False, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_vit_kurtis.to(device)

# Optimizer with Weight Decay
optimizer = torch.optim.AdamW(model_vit_kurtis.parameters(), lr=5e-5, weight_decay=0.01)

# Learning Rate Scheduler: Cosine Annealing
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

scaler = GradScaler()

# Training Loop 
num_epochs = 10

for epoch in range(num_epochs):
    model_vit_kurtis.train()
    running_corrects = {attr: 0 for attr in attribute_columns_kurtis}
    total_samples = {attr: 0 for attr in attribute_columns_kurtis}

    for images, labels in tqdm(train_loader_kurtis, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        with autocast():
            outputs = model_vit_kurtis(pixel_values=images)
            losses = {}
            for attr in labels.keys():
                loss_fn = nn.CrossEntropyLoss()
                target = labels[attr].long().to(device)
                loss = loss_fn(outputs[attr], target)
                losses[attr] = loss

                # Compute accuracy
                _, preds = torch.max(outputs[attr], 1)
                running_corrects[attr] += torch.sum(preds == target.data).item()
                total_samples[attr] += target.size(0)

            total_loss = sum(losses.values())

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    scheduler.step()

    # Validation Accuracy
    model_vit_kurtis.eval()
    val_corrects = {attr: 0 for attr in attribute_columns_kurtis}
    val_samples = {attr: 0 for attr in attribute_columns_kurtis}

    with torch.no_grad():
        for images, labels in tqdm(val_loader_kurtis, desc="Validating", leave=False):
            images = images.to(device)
            outputs = model_vit_kurtis(pixel_values=images)
            for attr in labels.keys():
                target = labels[attr].long().to(device)
                _, preds = torch.max(outputs[attr], 1)
                val_corrects[attr] += torch.sum(preds == target.data).item()
                val_samples[attr] += target.size(0)

    for attr in attribute_columns_kurtis:
        val_acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Epoch {epoch + 1}, Attribute {attr}, Validation Accuracy: {val_acc:.4f}")

# Final test accuracy
model_vit_kurtis.eval()
test_corrects = {attr: 0 for attr in attribute_columns_kurtis}
test_samples = {attr: 0 for attr in attribute_columns_kurtis}

with torch.no_grad():
    for images, labels in tqdm(test_loader_kurtis, desc="Testing", leave=False):
        images = images.to(device)
        outputs = model_vit_kurtis(pixel_values=images)
        
        for attr in labels.keys():
            target = labels[attr].long().to(device)
            _, preds = torch.max(outputs[attr], 1)
            test_corrects[attr] += torch.sum(preds == target.data).item()
            test_samples[attr] += target.size(0)

for attr in attribute_columns_kurtis:
    test_acc = test_corrects[attr] / test_samples[attr] if test_samples[attr] > 0 else 0
    print(f"Final Test Accuracy for Attribute {attr}: {test_acc:.4f}")


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = GradScaler()
  self.pid = os.fork()
  with autocast():
  self.pid = os.fork()
                                                                      

Epoch 1, Attribute attr_1, Validation Accuracy: 0.6862
Epoch 1, Attribute attr_2, Validation Accuracy: 0.8817
Epoch 1, Attribute attr_3, Validation Accuracy: 0.8104
Epoch 1, Attribute attr_4, Validation Accuracy: 0.9355
Epoch 1, Attribute attr_5, Validation Accuracy: 0.9110
Epoch 1, Attribute attr_6, Validation Accuracy: 0.7370
Epoch 1, Attribute attr_7, Validation Accuracy: 0.7410
Epoch 1, Attribute attr_8, Validation Accuracy: 0.9179
Epoch 1, Attribute attr_9, Validation Accuracy: 0.9941
Epoch 1, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 2, Attribute attr_1, Validation Accuracy: 0.7546
Epoch 2, Attribute attr_2, Validation Accuracy: 0.8768
Epoch 2, Attribute attr_3, Validation Accuracy: 0.8113
Epoch 2, Attribute attr_4, Validation Accuracy: 0.9472
Epoch 2, Attribute attr_5, Validation Accuracy: 0.9130
Epoch 2, Attribute attr_6, Validation Accuracy: 0.7664
Epoch 2, Attribute attr_7, Validation Accuracy: 0.7419
Epoch 2, Attribute attr_8, Validation Accuracy: 0.9374
Epoch 2, Attribute attr_9, Validation Accuracy: 0.9971
Epoch 2, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 3, Attribute attr_1, Validation Accuracy: 0.7830
Epoch 3, Attribute attr_2, Validation Accuracy: 0.8817
Epoch 3, Attribute attr_3, Validation Accuracy: 0.8192
Epoch 3, Attribute attr_4, Validation Accuracy: 0.9433
Epoch 3, Attribute attr_5, Validation Accuracy: 0.9198
Epoch 3, Attribute attr_6, Validation Accuracy: 0.7742
Epoch 3, Attribute attr_7, Validation Accuracy: 0.7488
Epoch 3, Attribute attr_8, Validation Accuracy: 0.9570
Epoch 3, Attribute attr_9, Validation Accuracy: 0.9990
Epoch 3, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 4, Attribute attr_1, Validation Accuracy: 0.7898
Epoch 4, Attribute attr_2, Validation Accuracy: 0.8895
Epoch 4, Attribute attr_3, Validation Accuracy: 0.8240
Epoch 4, Attribute attr_4, Validation Accuracy: 0.9345
Epoch 4, Attribute attr_5, Validation Accuracy: 0.9247
Epoch 4, Attribute attr_6, Validation Accuracy: 0.7781
Epoch 4, Attribute attr_7, Validation Accuracy: 0.7556
Epoch 4, Attribute attr_8, Validation Accuracy: 0.9492
Epoch 4, Attribute attr_9, Validation Accuracy: 0.9980
Epoch 4, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 5, Attribute attr_1, Validation Accuracy: 0.7996
Epoch 5, Attribute attr_2, Validation Accuracy: 0.8886
Epoch 5, Attribute attr_3, Validation Accuracy: 0.8182
Epoch 5, Attribute attr_4, Validation Accuracy: 0.9433
Epoch 5, Attribute attr_5, Validation Accuracy: 0.9218
Epoch 5, Attribute attr_6, Validation Accuracy: 0.7771
Epoch 5, Attribute attr_7, Validation Accuracy: 0.7546
Epoch 5, Attribute attr_8, Validation Accuracy: 0.9589
Epoch 5, Attribute attr_9, Validation Accuracy: 0.9980
Epoch 5, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 6, Attribute attr_1, Validation Accuracy: 0.8035
Epoch 6, Attribute attr_2, Validation Accuracy: 0.8954
Epoch 6, Attribute attr_3, Validation Accuracy: 0.8289
Epoch 6, Attribute attr_4, Validation Accuracy: 0.9423
Epoch 6, Attribute attr_5, Validation Accuracy: 0.9189
Epoch 6, Attribute attr_6, Validation Accuracy: 0.7732
Epoch 6, Attribute attr_7, Validation Accuracy: 0.7478
Epoch 6, Attribute attr_8, Validation Accuracy: 0.9609
Epoch 6, Attribute attr_9, Validation Accuracy: 0.9980
Epoch 6, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 7, Attribute attr_1, Validation Accuracy: 0.8133
Epoch 7, Attribute attr_2, Validation Accuracy: 0.8954
Epoch 7, Attribute attr_3, Validation Accuracy: 0.8260
Epoch 7, Attribute attr_4, Validation Accuracy: 0.9433
Epoch 7, Attribute attr_5, Validation Accuracy: 0.9228
Epoch 7, Attribute attr_6, Validation Accuracy: 0.7859
Epoch 7, Attribute attr_7, Validation Accuracy: 0.7654
Epoch 7, Attribute attr_8, Validation Accuracy: 0.9629
Epoch 7, Attribute attr_9, Validation Accuracy: 0.9980
Epoch 7, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 8, Attribute attr_1, Validation Accuracy: 0.8250
Epoch 8, Attribute attr_2, Validation Accuracy: 0.8954
Epoch 8, Attribute attr_3, Validation Accuracy: 0.8133
Epoch 8, Attribute attr_4, Validation Accuracy: 0.9433
Epoch 8, Attribute attr_5, Validation Accuracy: 0.9208
Epoch 8, Attribute attr_6, Validation Accuracy: 0.7889
Epoch 8, Attribute attr_7, Validation Accuracy: 0.7595
Epoch 8, Attribute attr_8, Validation Accuracy: 0.9619
Epoch 8, Attribute attr_9, Validation Accuracy: 0.9980
Epoch 8, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 9, Attribute attr_1, Validation Accuracy: 0.8289
Epoch 9, Attribute attr_2, Validation Accuracy: 0.8954
Epoch 9, Attribute attr_3, Validation Accuracy: 0.8250
Epoch 9, Attribute attr_4, Validation Accuracy: 0.9453
Epoch 9, Attribute attr_5, Validation Accuracy: 0.9228
Epoch 9, Attribute attr_6, Validation Accuracy: 0.7791
Epoch 9, Attribute attr_7, Validation Accuracy: 0.7722
Epoch 9, Attribute attr_8, Validation Accuracy: 0.9629
Epoch 9, Attribute attr_9, Validation Accuracy: 0.9980
Epoch 9, Attribute attr_10, Validation Accuracy: 1.0000


                                                                       

Epoch 10, Attribute attr_1, Validation Accuracy: 0.8211
Epoch 10, Attribute attr_2, Validation Accuracy: 0.8954
Epoch 10, Attribute attr_3, Validation Accuracy: 0.8250
Epoch 10, Attribute attr_4, Validation Accuracy: 0.9433
Epoch 10, Attribute attr_5, Validation Accuracy: 0.9228
Epoch 10, Attribute attr_6, Validation Accuracy: 0.7849
Epoch 10, Attribute attr_7, Validation Accuracy: 0.7615
Epoch 10, Attribute attr_8, Validation Accuracy: 0.9638
Epoch 10, Attribute attr_9, Validation Accuracy: 0.9980
Epoch 10, Attribute attr_10, Validation Accuracy: 1.0000


                                                        

Final Test Accuracy for Attribute attr_1: 0.8076
Final Test Accuracy for Attribute attr_2: 0.8857
Final Test Accuracy for Attribute attr_3: 0.8320
Final Test Accuracy for Attribute attr_4: 0.9346
Final Test Accuracy for Attribute attr_5: 0.9316
Final Test Accuracy for Attribute attr_6: 0.7695
Final Test Accuracy for Attribute attr_7: 0.7656
Final Test Accuracy for Attribute attr_8: 0.9648
Final Test Accuracy for Attribute attr_9: 0.9941
Final Test Accuracy for Attribute attr_10: 1.0000




In [None]:
women_tshirts_df = women_tshirts_df.drop(columns=['Category', 'len'])
women_tshirts_df = women_tshirts_df.apply(
    lambda col: col.fillna('dummy_value') if col.isna().all() 
    else col.fillna(col.mode()[0]) if col.dtype == 'object' or col.dtype.name == 'category' 
    else col
)

attribute_columns_women_tshirts = [f'attr_{i}' for i in range(1, 11)]
category_mappings_women_tshirts = {}

for col in attribute_columns_women_tshirts:
    women_tshirts_df[col] = pd.Categorical(women_tshirts_df[col])
    category_mappings_women_tshirts[col] = dict(enumerate(women_tshirts_df[col].cat.categories))
    women_tshirts_df[col] = women_tshirts_df[col].cat.codes

women_tshirts_df['image_path'] = '/kaggle/input/visual-taxonomy/train_images/' + women_tshirts_df['id'].apply(lambda x: f"{x:06d}.jpg")

train_df_women_tshirts, temp_df = train_test_split(women_tshirts_df, test_size=0.3, random_state=42)
val_df_women_tshirts, test_df_women_tshirts = train_test_split(temp_df, test_size=0.5, random_state=42)

num_classes_per_attribute_women_tshirts = {
    'attr_1': 7,
    'attr_2': 3,
    'attr_3': 3,
    'attr_4': 3,
    'attr_5': 6,
    'attr_6': 3,
    'attr_7': 2,
    'attr_8': 2,
    'attr_9': 1,
    'attr_10': 1
}

# Model with Dropout and Learning Rate Scheduling for women_tshirts
class MultiOutputModel(nn.Module):
    def __init__(self, base_model, num_classes_per_attribute):
        super(MultiOutputModel, self).__init__()
        self.base_model = base_model

        self.dropout = nn.Dropout(0.3)  # Adding Dropout
        self.heads = nn.ModuleDict({
            attr: nn.Linear(base_model.config.hidden_size, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })

    def forward(self, pixel_values):
        outputs = self.base_model(pixel_values=pixel_values)
        hidden_states = outputs.last_hidden_state
        cls_hidden_states = hidden_states[:, 0, :]
        cls_hidden_states = self.dropout(cls_hidden_states)  # Apply dropout
        logits = {attr: self.heads[attr](cls_hidden_states) for attr in self.heads}
        return logits

base_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model_vit_women_tshirts = MultiOutputModel(base_model, num_classes_per_attribute_women_tshirts)

# Data Augmentation for women_tshirts
data_augmentation = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class CustomDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        id_str = str(self.df.iloc[idx]["id"]).zfill(6)
        path = self.df.iloc[idx]["image_path"]
        image = Image.open(path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels = {}
        
        for attr in attribute_columns_women_tshirts:
            labels[attr] = torch.tensor(self.df.iloc[idx][attr], dtype=torch.long)

        return image, labels

# Initialize datasets and data loaders
train_dataset = CustomDataset(train_df_women_tshirts, transform=data_augmentation)
val_dataset = CustomDataset(val_df_women_tshirts, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))
test_dataset = CustomDataset(test_df_women_tshirts, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_vit_women_tshirts.to(device)

# Optimizer and Learning Rate Scheduler for women_tshirts
optimizer = torch.optim.AdamW(model_vit_women_tshirts.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
scaler = GradScaler()

# Training Loop for women_tshirts
num_epochs = 10

for epoch in range(num_epochs):
    model_vit_women_tshirts.train()
    running_corrects = {attr: 0 for attr in attribute_columns_women_tshirts}
    total_samples = {attr: 0 for attr in attribute_columns_women_tshirts}

    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        with autocast():
            outputs = model_vit_women_tshirts(pixel_values=images)
            
            losses = {}
            
            for attr in labels.keys():
                loss_fn = nn.CrossEntropyLoss()
                target = labels[attr].long().to(device)
                loss = loss_fn(outputs[attr], target)
                losses[attr] = loss

                # Compute accuracy
                _, preds = torch.max(outputs[attr], 1)
                running_corrects[attr] += torch.sum(preds == target.data).item()
                total_samples[attr] += target.size(0)

            total_loss = sum(losses.values())

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # Update learning rate
    scheduler.step()

    # Validation Accuracy
    model_vit_women_tshirts.eval()
    val_corrects = {attr: 0 for attr in attribute_columns_women_tshirts}
    val_samples = {attr: 0 for attr in attribute_columns_women_tshirts}

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            outputs = model_vit_women_tshirts(pixel_values=images)

            for attr in labels.keys():
                loss_fn = nn.CrossEntropyLoss()
                target = labels[attr].long().to(device)
                _, preds = torch.max(outputs[attr], 1)
                val_corrects[attr] += torch.sum(preds == target.data).item()
                val_samples[attr] += target.size(0)

    for attr in attribute_columns_women_tshirts:
        val_acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Epoch {epoch + 1}, Attribute {attr}, Validation Accuracy: {val_acc:.4f}")

# Final Test Accuracy
model_vit_women_tshirts.eval()
test_corrects = {attr: 0 for attr in attribute_columns_women_tshirts}
test_samples = {attr: 0 for attr in attribute_columns_women_tshirts}

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing", leave=False):
        images = images.to(device)
        outputs = model_vit_women_tshirts(pixel_values=images)
        
        for attr in labels.keys():
            target = labels[attr].long().to(device)
            _, preds = torch.max(outputs[attr], 1)
            test_corrects[attr] += torch.sum(preds == target.data).item()
            test_samples[attr] += target.size(0)

# Print final test accuracy for each attribute
for attr in attribute_columns_women_tshirts:
    test_acc = test_corrects[attr] / test_samples[attr] if test_samples[attr] > 0 else 0
    print(f"Final Test Accuracy for Attribute {attr}: {test_acc:.4f}")


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = GradScaler()
  self.pid = os.fork()
  with autocast():
  self.pid = os.fork()
                                                                      

Epoch 1, Attribute attr_1, Validation Accuracy: 0.7592
Epoch 1, Attribute attr_2, Validation Accuracy: 0.8952
Epoch 1, Attribute attr_3, Validation Accuracy: 0.7827
Epoch 1, Attribute attr_4, Validation Accuracy: 0.9574
Epoch 1, Attribute attr_5, Validation Accuracy: 0.6918
Epoch 1, Attribute attr_6, Validation Accuracy: 0.9439
Epoch 1, Attribute attr_7, Validation Accuracy: 0.9613
Epoch 1, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 1, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 1, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 2, Attribute attr_1, Validation Accuracy: 0.7816
Epoch 2, Attribute attr_2, Validation Accuracy: 0.9023
Epoch 2, Attribute attr_3, Validation Accuracy: 0.8075
Epoch 2, Attribute attr_4, Validation Accuracy: 0.9634
Epoch 2, Attribute attr_5, Validation Accuracy: 0.7152
Epoch 2, Attribute attr_6, Validation Accuracy: 0.9442
Epoch 2, Attribute attr_7, Validation Accuracy: 0.9688
Epoch 2, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 2, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 2, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 3, Attribute attr_1, Validation Accuracy: 0.7958
Epoch 3, Attribute attr_2, Validation Accuracy: 0.9070
Epoch 3, Attribute attr_3, Validation Accuracy: 0.8168
Epoch 3, Attribute attr_4, Validation Accuracy: 0.9691
Epoch 3, Attribute attr_5, Validation Accuracy: 0.7248
Epoch 3, Attribute attr_6, Validation Accuracy: 0.9425
Epoch 3, Attribute attr_7, Validation Accuracy: 0.9702
Epoch 3, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 3, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 3, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 4, Attribute attr_1, Validation Accuracy: 0.7997
Epoch 4, Attribute attr_2, Validation Accuracy: 0.9098
Epoch 4, Attribute attr_3, Validation Accuracy: 0.8178
Epoch 4, Attribute attr_4, Validation Accuracy: 0.9698
Epoch 4, Attribute attr_5, Validation Accuracy: 0.7326
Epoch 4, Attribute attr_6, Validation Accuracy: 0.9471
Epoch 4, Attribute attr_7, Validation Accuracy: 0.9698
Epoch 4, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 4, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 4, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 5, Attribute attr_1, Validation Accuracy: 0.7976
Epoch 5, Attribute attr_2, Validation Accuracy: 0.9141
Epoch 5, Attribute attr_3, Validation Accuracy: 0.8143
Epoch 5, Attribute attr_4, Validation Accuracy: 0.9716
Epoch 5, Attribute attr_5, Validation Accuracy: 0.7376
Epoch 5, Attribute attr_6, Validation Accuracy: 0.9489
Epoch 5, Attribute attr_7, Validation Accuracy: 0.9755
Epoch 5, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 5, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 5, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 6, Attribute attr_1, Validation Accuracy: 0.8050
Epoch 6, Attribute attr_2, Validation Accuracy: 0.9158
Epoch 6, Attribute attr_3, Validation Accuracy: 0.8256
Epoch 6, Attribute attr_4, Validation Accuracy: 0.9702
Epoch 6, Attribute attr_5, Validation Accuracy: 0.7326
Epoch 6, Attribute attr_6, Validation Accuracy: 0.9474
Epoch 6, Attribute attr_7, Validation Accuracy: 0.9695
Epoch 6, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 6, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 6, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 7, Attribute attr_1, Validation Accuracy: 0.8008
Epoch 7, Attribute attr_2, Validation Accuracy: 0.9134
Epoch 7, Attribute attr_3, Validation Accuracy: 0.8189
Epoch 7, Attribute attr_4, Validation Accuracy: 0.9698
Epoch 7, Attribute attr_5, Validation Accuracy: 0.7454
Epoch 7, Attribute attr_6, Validation Accuracy: 0.9446
Epoch 7, Attribute attr_7, Validation Accuracy: 0.9673
Epoch 7, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 7, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 7, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 8, Attribute attr_1, Validation Accuracy: 0.8065
Epoch 8, Attribute attr_2, Validation Accuracy: 0.9155
Epoch 8, Attribute attr_3, Validation Accuracy: 0.8313
Epoch 8, Attribute attr_4, Validation Accuracy: 0.9673
Epoch 8, Attribute attr_5, Validation Accuracy: 0.7411
Epoch 8, Attribute attr_6, Validation Accuracy: 0.9467
Epoch 8, Attribute attr_7, Validation Accuracy: 0.9648
Epoch 8, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 8, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 8, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 9, Attribute attr_1, Validation Accuracy: 0.8061
Epoch 9, Attribute attr_2, Validation Accuracy: 0.9165
Epoch 9, Attribute attr_3, Validation Accuracy: 0.8292
Epoch 9, Attribute attr_4, Validation Accuracy: 0.9680
Epoch 9, Attribute attr_5, Validation Accuracy: 0.7422
Epoch 9, Attribute attr_6, Validation Accuracy: 0.9467
Epoch 9, Attribute attr_7, Validation Accuracy: 0.9663
Epoch 9, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 9, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 9, Attribute attr_10, Validation Accuracy: 1.0000


                                                                       

Epoch 10, Attribute attr_1, Validation Accuracy: 0.8047
Epoch 10, Attribute attr_2, Validation Accuracy: 0.9187
Epoch 10, Attribute attr_3, Validation Accuracy: 0.8281
Epoch 10, Attribute attr_4, Validation Accuracy: 0.9680
Epoch 10, Attribute attr_5, Validation Accuracy: 0.7436
Epoch 10, Attribute attr_6, Validation Accuracy: 0.9467
Epoch 10, Attribute attr_7, Validation Accuracy: 0.9663
Epoch 10, Attribute attr_8, Validation Accuracy: 0.9986
Epoch 10, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 10, Attribute attr_10, Validation Accuracy: 1.0000


                                                        

Final Test Accuracy for Attribute attr_1: 0.7998
Final Test Accuracy for Attribute attr_2: 0.9038
Final Test Accuracy for Attribute attr_3: 0.8250
Final Test Accuracy for Attribute attr_4: 0.9681
Final Test Accuracy for Attribute attr_5: 0.7316
Final Test Accuracy for Attribute attr_6: 0.9457
Final Test Accuracy for Attribute attr_7: 0.9673
Final Test Accuracy for Attribute attr_8: 0.9996
Final Test Accuracy for Attribute attr_9: 1.0000
Final Test Accuracy for Attribute attr_10: 1.0000




In [None]:
women_tops_df = women_tops_df.drop(columns=['Category', 'len'])
women_tops_df = women_tops_df.apply(
    lambda col: col.fillna('dummy_value') if col.isna().all() 
    else col.fillna(col.mode()[0]) if col.dtype == 'object' or col.dtype.name == 'category' 
    else col
)

attribute_columns_women_tops = [f'attr_{i}' for i in range(1, 11)]
category_mappings_women_tops = {}

for col in attribute_columns_women_tops:
    women_tops_df[col] = pd.Categorical(women_tops_df[col])
    category_mappings_women_tops[col] = dict(enumerate(women_tops_df[col].cat.categories))
    women_tops_df[col] = women_tops_df[col].cat.codes

women_tops_df['image_path'] = '/kaggle/input/visual-taxonomy/train_images/' + women_tops_df['id'].apply(lambda x: f"{x:06d}.jpg")

train_df_women_tops, temp_df = train_test_split(women_tops_df, test_size=0.3, random_state=42)
val_df_women_tops, test_df_women_tops = train_test_split(temp_df, test_size=0.5, random_state=42)

num_classes_per_attribute_women_tops = {
    'attr_1': 12,
    'attr_2': 4,
    'attr_3': 2,
    'attr_4': 7,
    'attr_5': 2,
    'attr_6': 3,
    'attr_7': 6,
    'attr_8': 4,
    'attr_9': 4,
    'attr_10': 6
}

# Model with Dropout and Learning Rate Scheduling for women_tops
class MultiOutputModel(nn.Module):
    def __init__(self, base_model, num_classes_per_attribute):
        super(MultiOutputModel, self).__init__()
        self.base_model = base_model

        self.dropout = nn.Dropout(0.3)  # Adding Dropout
        self.heads = nn.ModuleDict({
            attr: nn.Linear(base_model.config.hidden_size, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })

    def forward(self, pixel_values):
        outputs = self.base_model(pixel_values=pixel_values)
        hidden_states = outputs.last_hidden_state
        cls_hidden_states = hidden_states[:, 0, :]
        cls_hidden_states = self.dropout(cls_hidden_states)  # Apply dropout
        logits = {attr: self.heads[attr](cls_hidden_states) for attr in self.heads}
        return logits

base_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model_vit_women_tops = MultiOutputModel(base_model, num_classes_per_attribute_women_tops)

# Data Augmentation for women_tops
data_augmentation = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class CustomDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        id_str = str(self.df.iloc[idx]["id"]).zfill(6)
        path = self.df.iloc[idx]["image_path"]
        image = Image.open(path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels = {}
        
        for attr in attribute_columns_women_tops:
            labels[attr] = torch.tensor(self.df.iloc[idx][attr], dtype=torch.long)

        return image, labels

# Initialize datasets and data loaders
train_dataset = CustomDataset(train_df_women_tops, transform=data_augmentation)
val_dataset = CustomDataset(val_df_women_tops, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))
test_dataset = CustomDataset(test_df_women_tops, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_vit_women_tops.to(device)

# Optimizer and Learning Rate Scheduler for women_tops
optimizer = torch.optim.AdamW(model_vit_women_tops.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
scaler = GradScaler()

# Training Loop for women_tops
num_epochs = 10

for epoch in range(num_epochs):
    model_vit_women_tops.train()
    running_corrects = {attr: 0 for attr in attribute_columns_women_tops}
    total_samples = {attr: 0 for attr in attribute_columns_women_tops}

    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        with autocast():
            outputs = model_vit_women_tops(pixel_values=images)
            
            losses = {}
            
            for attr in labels.keys():
                loss_fn = nn.CrossEntropyLoss()
                target = labels[attr].long().to(device)
                loss = loss_fn(outputs[attr], target)
                losses[attr] = loss

                # Compute accuracy
                _, preds = torch.max(outputs[attr], 1)
                running_corrects[attr] += torch.sum(preds == target.data).item()
                total_samples[attr] += target.size(0)

            total_loss = sum(losses.values())

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # Update learning rate
    scheduler.step()

    # Validation Accuracy
    model_vit_women_tops.eval()
    val_corrects = {attr: 0 for attr in attribute_columns_women_tops}
    val_samples = {attr: 0 for attr in attribute_columns_women_tops}

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            outputs = model_vit_women_tops(pixel_values=images)

            for attr in labels.keys():
                loss_fn = nn.CrossEntropyLoss()
                target = labels[attr].long().to(device)
                _, preds = torch.max(outputs[attr], 1)
                val_corrects[attr] += torch.sum(preds == target.data).item()
                val_samples[attr] += target.size(0)

    for attr in attribute_columns_women_tops:
        val_acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Epoch {epoch + 1}, Attribute {attr}, Validation Accuracy: {val_acc:.4f}")

# Final Test Accuracy
model_vit_women_tops.eval()
test_corrects = {attr: 0 for attr in attribute_columns_women_tops}
test_samples = {attr: 0 for attr in attribute_columns_women_tops}

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing", leave=False):
        images = images.to(device)
        outputs = model_vit_women_tops(pixel_values=images)
        
        for attr in labels.keys():
            target = labels[attr].long().to(device)
            _, preds = torch.max(outputs[attr], 1)
            test_corrects[attr] += torch.sum(preds == target.data).item()
            test_samples[attr] += target.size(0)

# Print final test accuracy for each attribute
for attr in attribute_columns_women_tops:
    test_acc = test_corrects[attr] / test_samples[attr] if test_samples[attr] > 0 else 0
    print(f"Final Test Accuracy for Attribute {attr}: {test_acc:.4f}")


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = GradScaler()
  self.pid = os.fork()
  with autocast():
  self.pid = os.fork()
                                                                      

Epoch 1, Attribute attr_1, Validation Accuracy: 0.5851
Epoch 1, Attribute attr_2, Validation Accuracy: 0.7425
Epoch 1, Attribute attr_3, Validation Accuracy: 0.8446
Epoch 1, Attribute attr_4, Validation Accuracy: 0.6875
Epoch 1, Attribute attr_5, Validation Accuracy: 0.9902
Epoch 1, Attribute attr_6, Validation Accuracy: 0.8923
Epoch 1, Attribute attr_7, Validation Accuracy: 0.8334
Epoch 1, Attribute attr_8, Validation Accuracy: 0.8127
Epoch 1, Attribute attr_9, Validation Accuracy: 0.8225
Epoch 1, Attribute attr_10, Validation Accuracy: 0.8015


                                                                      

Epoch 2, Attribute attr_1, Validation Accuracy: 0.6226
Epoch 2, Attribute attr_2, Validation Accuracy: 0.7538
Epoch 2, Attribute attr_3, Validation Accuracy: 0.8548
Epoch 2, Attribute attr_4, Validation Accuracy: 0.7078
Epoch 2, Attribute attr_5, Validation Accuracy: 0.9902
Epoch 2, Attribute attr_6, Validation Accuracy: 0.8990
Epoch 2, Attribute attr_7, Validation Accuracy: 0.8320
Epoch 2, Attribute attr_8, Validation Accuracy: 0.8281
Epoch 2, Attribute attr_9, Validation Accuracy: 0.8348
Epoch 2, Attribute attr_10, Validation Accuracy: 0.8162


                                                                      

Epoch 3, Attribute attr_1, Validation Accuracy: 0.6342
Epoch 3, Attribute attr_2, Validation Accuracy: 0.7576
Epoch 3, Attribute attr_3, Validation Accuracy: 0.8523
Epoch 3, Attribute attr_4, Validation Accuracy: 0.7155
Epoch 3, Attribute attr_5, Validation Accuracy: 0.9902
Epoch 3, Attribute attr_6, Validation Accuracy: 0.8923
Epoch 3, Attribute attr_7, Validation Accuracy: 0.8394
Epoch 3, Attribute attr_8, Validation Accuracy: 0.8225
Epoch 3, Attribute attr_9, Validation Accuracy: 0.8355
Epoch 3, Attribute attr_10, Validation Accuracy: 0.8197


                                                                      

Epoch 4, Attribute attr_1, Validation Accuracy: 0.6563
Epoch 4, Attribute attr_2, Validation Accuracy: 0.7618
Epoch 4, Attribute attr_3, Validation Accuracy: 0.8583
Epoch 4, Attribute attr_4, Validation Accuracy: 0.7226
Epoch 4, Attribute attr_5, Validation Accuracy: 0.9902
Epoch 4, Attribute attr_6, Validation Accuracy: 0.8997
Epoch 4, Attribute attr_7, Validation Accuracy: 0.8341
Epoch 4, Attribute attr_8, Validation Accuracy: 0.8236
Epoch 4, Attribute attr_9, Validation Accuracy: 0.8425
Epoch 4, Attribute attr_10, Validation Accuracy: 0.8208


                                                                      

Epoch 5, Attribute attr_1, Validation Accuracy: 0.6657
Epoch 5, Attribute attr_2, Validation Accuracy: 0.7685
Epoch 5, Attribute attr_3, Validation Accuracy: 0.8604
Epoch 5, Attribute attr_4, Validation Accuracy: 0.7306
Epoch 5, Attribute attr_5, Validation Accuracy: 0.9902
Epoch 5, Attribute attr_6, Validation Accuracy: 0.9004
Epoch 5, Attribute attr_7, Validation Accuracy: 0.8474
Epoch 5, Attribute attr_8, Validation Accuracy: 0.8390
Epoch 5, Attribute attr_9, Validation Accuracy: 0.8474
Epoch 5, Attribute attr_10, Validation Accuracy: 0.8208


                                                                      

Epoch 6, Attribute attr_1, Validation Accuracy: 0.6717
Epoch 6, Attribute attr_2, Validation Accuracy: 0.7643
Epoch 6, Attribute attr_3, Validation Accuracy: 0.8632
Epoch 6, Attribute attr_4, Validation Accuracy: 0.7285
Epoch 6, Attribute attr_5, Validation Accuracy: 0.9902
Epoch 6, Attribute attr_6, Validation Accuracy: 0.9004
Epoch 6, Attribute attr_7, Validation Accuracy: 0.8401
Epoch 6, Attribute attr_8, Validation Accuracy: 0.8295
Epoch 6, Attribute attr_9, Validation Accuracy: 0.8555
Epoch 6, Attribute attr_10, Validation Accuracy: 0.8225


                                                                      

Epoch 7, Attribute attr_1, Validation Accuracy: 0.6780
Epoch 7, Attribute attr_2, Validation Accuracy: 0.7675
Epoch 7, Attribute attr_3, Validation Accuracy: 0.8639
Epoch 7, Attribute attr_4, Validation Accuracy: 0.7327
Epoch 7, Attribute attr_5, Validation Accuracy: 0.9902
Epoch 7, Attribute attr_6, Validation Accuracy: 0.8986
Epoch 7, Attribute attr_7, Validation Accuracy: 0.8450
Epoch 7, Attribute attr_8, Validation Accuracy: 0.8337
Epoch 7, Attribute attr_9, Validation Accuracy: 0.8544
Epoch 7, Attribute attr_10, Validation Accuracy: 0.8243


                                                                      

Epoch 8, Attribute attr_1, Validation Accuracy: 0.6812
Epoch 8, Attribute attr_2, Validation Accuracy: 0.7675
Epoch 8, Attribute attr_3, Validation Accuracy: 0.8646
Epoch 8, Attribute attr_4, Validation Accuracy: 0.7348
Epoch 8, Attribute attr_5, Validation Accuracy: 0.9902
Epoch 8, Attribute attr_6, Validation Accuracy: 0.8990
Epoch 8, Attribute attr_7, Validation Accuracy: 0.8464
Epoch 8, Attribute attr_8, Validation Accuracy: 0.8337
Epoch 8, Attribute attr_9, Validation Accuracy: 0.8513
Epoch 8, Attribute attr_10, Validation Accuracy: 0.8285


                                                                      

Epoch 9, Attribute attr_1, Validation Accuracy: 0.6812
Epoch 9, Attribute attr_2, Validation Accuracy: 0.7710
Epoch 9, Attribute attr_3, Validation Accuracy: 0.8643
Epoch 9, Attribute attr_4, Validation Accuracy: 0.7376
Epoch 9, Attribute attr_5, Validation Accuracy: 0.9905
Epoch 9, Attribute attr_6, Validation Accuracy: 0.8990
Epoch 9, Attribute attr_7, Validation Accuracy: 0.8478
Epoch 9, Attribute attr_8, Validation Accuracy: 0.8351
Epoch 9, Attribute attr_9, Validation Accuracy: 0.8513
Epoch 9, Attribute attr_10, Validation Accuracy: 0.8278


                                                                       

Epoch 10, Attribute attr_1, Validation Accuracy: 0.6808
Epoch 10, Attribute attr_2, Validation Accuracy: 0.7699
Epoch 10, Attribute attr_3, Validation Accuracy: 0.8636
Epoch 10, Attribute attr_4, Validation Accuracy: 0.7369
Epoch 10, Attribute attr_5, Validation Accuracy: 0.9905
Epoch 10, Attribute attr_6, Validation Accuracy: 0.9011
Epoch 10, Attribute attr_7, Validation Accuracy: 0.8488
Epoch 10, Attribute attr_8, Validation Accuracy: 0.8365
Epoch 10, Attribute attr_9, Validation Accuracy: 0.8499
Epoch 10, Attribute attr_10, Validation Accuracy: 0.8271


                                                        

Final Test Accuracy for Attribute attr_1: 0.6734
Final Test Accuracy for Attribute attr_2: 0.7724
Final Test Accuracy for Attribute attr_3: 0.8716
Final Test Accuracy for Attribute attr_4: 0.7461
Final Test Accuracy for Attribute attr_5: 0.9895
Final Test Accuracy for Attribute attr_6: 0.9021
Final Test Accuracy for Attribute attr_7: 0.8558
Final Test Accuracy for Attribute attr_8: 0.8478
Final Test Accuracy for Attribute attr_9: 0.8365
Final Test Accuracy for Attribute attr_10: 0.8295




In [None]:
men_tshirts_df = men_tshirts_df.drop(columns=['Category', 'len'])
men_tshirts_df = men_tshirts_df.apply(
    lambda col: col.fillna('dummy_value') if col.isna().all() 
    else col.fillna(col.mode()[0]) if col.dtype == 'object' or col.dtype.name == 'category' 
    else col
)

attribute_columns_men_tshirts = [f'attr_{i}' for i in range(1, 11)]
category_mappings_men_tshirts = {}

for col in attribute_columns_men_tshirts:
    men_tshirts_df[col] = pd.Categorical(men_tshirts_df[col])
    category_mappings_men_tshirts[col] = dict(enumerate(men_tshirts_df[col].cat.categories))
    men_tshirts_df[col] = men_tshirts_df[col].cat.codes

men_tshirts_df['image_path'] = '/kaggle/input/visual-taxonomy/train_images/' + men_tshirts_df['id'].apply(lambda x: f"{x:06d}.jpg")

train_df_men_tshirts, temp_df = train_test_split(men_tshirts_df, test_size=0.3, random_state=42)
val_df_men_tshirts, test_df_men_tshirts = train_test_split(temp_df, test_size=0.5, random_state=42)

num_classes_per_attribute_men_tshirts = {
    'attr_1': 4,
    'attr_2': 2,
    'attr_3': 2,
    'attr_4': 3,
    'attr_5': 3,
    'attr_6': 1,
    'attr_7': 1,
    'attr_8': 1,
    'attr_9': 1,
    'attr_10': 1
}

# Model with Dropout and Learning Rate Scheduling for men_tshirts
class MultiOutputModel(nn.Module):
    def __init__(self, base_model, num_classes_per_attribute):
        super(MultiOutputModel, self).__init__()
        self.base_model = base_model

        self.dropout = nn.Dropout(0.3)  # Adding Dropout
        self.heads = nn.ModuleDict({
            attr: nn.Linear(base_model.config.hidden_size, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })

    def forward(self, pixel_values):
        outputs = self.base_model(pixel_values=pixel_values)
        hidden_states = outputs.last_hidden_state
        cls_hidden_states = hidden_states[:, 0, :]
        cls_hidden_states = self.dropout(cls_hidden_states)  # Apply dropout
        logits = {attr: self.heads[attr](cls_hidden_states) for attr in self.heads}
        return logits

base_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model_vit_men_tshirts = MultiOutputModel(base_model, num_classes_per_attribute_men_tshirts)

# Data Augmentation for men_tshirts
data_augmentation = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class CustomDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        id_str = str(self.df.iloc[idx]["id"]).zfill(6)
        path = self.df.iloc[idx]["image_path"]
        image = Image.open(path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels = {}
        
        for attr in attribute_columns_men_tshirts:
            labels[attr] = torch.tensor(self.df.iloc[idx][attr], dtype=torch.long)

        return image, labels

# Initialize datasets and data loaders
train_dataset = CustomDataset(train_df_men_tshirts, transform=data_augmentation)
val_dataset = CustomDataset(val_df_men_tshirts, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))
test_dataset = CustomDataset(test_df_men_tshirts, transform=T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Device Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_vit_men_tshirts.to(device)

# Optimizer and Learning Rate Scheduler for men_tshirts
optimizer = torch.optim.AdamW(model_vit_men_tshirts.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
scaler = GradScaler()

# Training Loop for men_tshirts
num_epochs = 10

for epoch in range(num_epochs):
    model_vit_men_tshirts.train()
    running_corrects = {attr: 0 for attr in attribute_columns_men_tshirts}
    total_samples = {attr: 0 for attr in attribute_columns_men_tshirts}

    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        with autocast():
            outputs = model_vit_men_tshirts(pixel_values=images)
            
            losses = {}
            
            for attr in labels.keys():
                loss_fn = nn.CrossEntropyLoss()
                target = labels[attr].long().to(device)
                loss = loss_fn(outputs[attr], target)
                losses[attr] = loss

                # Compute accuracy
                _, preds = torch.max(outputs[attr], 1)
                running_corrects[attr] += torch.sum(preds == target.data).item()
                total_samples[attr] += target.size(0)

            total_loss = sum(losses.values())

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # Update learning rate
    scheduler.step()

    # Validation Accuracy
    model_vit_men_tshirts.eval()
    val_corrects = {attr: 0 for attr in attribute_columns_men_tshirts}
    val_samples = {attr: 0 for attr in attribute_columns_men_tshirts}

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            outputs = model_vit_men_tshirts(pixel_values=images)

            for attr in labels.keys():
                loss_fn = nn.CrossEntropyLoss()
                target = labels[attr].long().to(device)
                _, preds = torch.max(outputs[attr], 1)
                val_corrects[attr] += torch.sum(preds == target.data).item()
                val_samples[attr] += target.size(0)

    for attr in attribute_columns_men_tshirts:
        val_acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Epoch {epoch + 1}, Attribute {attr}, Validation Accuracy: {val_acc:.4f}")

# Final Test Accuracy
model_vit_men_tshirts.eval()
test_corrects = {attr: 0 for attr in attribute_columns_men_tshirts}
test_samples = {attr: 0 for attr in attribute_columns_men_tshirts}

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing", leave=False):
        images = images.to(device)
        outputs = model_vit_men_tshirts(pixel_values=images)
        
        for attr in labels.keys():
            target = labels[attr].long().to(device)
            _, preds = torch.max(outputs[attr], 1)
            test_corrects[attr] += torch.sum(preds == target.data).item()
            test_samples[attr] += target.size(0)

# Print final test accuracy for each attribute
for attr in attribute_columns_men_tshirts:
    test_acc = test_corrects[attr] / test_samples[attr] if test_samples[attr] > 0 else 0
    print(f"Final Test Accuracy for Attribute {attr}: {test_acc:.4f}")


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = GradScaler()
  self.pid = os.fork()
  with autocast():
  self.pid = os.fork()
                                                                      

Epoch 1, Attribute attr_1, Validation Accuracy: 0.6165
Epoch 1, Attribute attr_2, Validation Accuracy: 0.8798
Epoch 1, Attribute attr_3, Validation Accuracy: 0.8376
Epoch 1, Attribute attr_4, Validation Accuracy: 0.7596
Epoch 1, Attribute attr_5, Validation Accuracy: 0.9633
Epoch 1, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 1, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 1, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 1, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 1, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 2, Attribute attr_1, Validation Accuracy: 0.6183
Epoch 2, Attribute attr_2, Validation Accuracy: 0.8642
Epoch 2, Attribute attr_3, Validation Accuracy: 0.8257
Epoch 2, Attribute attr_4, Validation Accuracy: 0.7119
Epoch 2, Attribute attr_5, Validation Accuracy: 0.9771
Epoch 2, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 2, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 2, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 2, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 2, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 3, Attribute attr_1, Validation Accuracy: 0.6284
Epoch 3, Attribute attr_2, Validation Accuracy: 0.8826
Epoch 3, Attribute attr_3, Validation Accuracy: 0.8468
Epoch 3, Attribute attr_4, Validation Accuracy: 0.7734
Epoch 3, Attribute attr_5, Validation Accuracy: 0.9752
Epoch 3, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 3, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 3, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 3, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 3, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 4, Attribute attr_1, Validation Accuracy: 0.6431
Epoch 4, Attribute attr_2, Validation Accuracy: 0.8835
Epoch 4, Attribute attr_3, Validation Accuracy: 0.8450
Epoch 4, Attribute attr_4, Validation Accuracy: 0.7881
Epoch 4, Attribute attr_5, Validation Accuracy: 0.9725
Epoch 4, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 4, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 4, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 4, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 4, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 5, Attribute attr_1, Validation Accuracy: 0.6294
Epoch 5, Attribute attr_2, Validation Accuracy: 0.8853
Epoch 5, Attribute attr_3, Validation Accuracy: 0.8349
Epoch 5, Attribute attr_4, Validation Accuracy: 0.7826
Epoch 5, Attribute attr_5, Validation Accuracy: 0.9716
Epoch 5, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 5, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 5, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 5, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 5, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 6, Attribute attr_1, Validation Accuracy: 0.6560
Epoch 6, Attribute attr_2, Validation Accuracy: 0.8789
Epoch 6, Attribute attr_3, Validation Accuracy: 0.8330
Epoch 6, Attribute attr_4, Validation Accuracy: 0.7954
Epoch 6, Attribute attr_5, Validation Accuracy: 0.9706
Epoch 6, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 6, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 6, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 6, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 6, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 7, Attribute attr_1, Validation Accuracy: 0.6404
Epoch 7, Attribute attr_2, Validation Accuracy: 0.8716
Epoch 7, Attribute attr_3, Validation Accuracy: 0.8349
Epoch 7, Attribute attr_4, Validation Accuracy: 0.7789
Epoch 7, Attribute attr_5, Validation Accuracy: 0.9716
Epoch 7, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 7, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 7, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 7, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 7, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 8, Attribute attr_1, Validation Accuracy: 0.6468
Epoch 8, Attribute attr_2, Validation Accuracy: 0.8697
Epoch 8, Attribute attr_3, Validation Accuracy: 0.8367
Epoch 8, Attribute attr_4, Validation Accuracy: 0.7972
Epoch 8, Attribute attr_5, Validation Accuracy: 0.9716
Epoch 8, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 8, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 8, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 8, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 8, Attribute attr_10, Validation Accuracy: 1.0000


                                                                      

Epoch 9, Attribute attr_1, Validation Accuracy: 0.6541
Epoch 9, Attribute attr_2, Validation Accuracy: 0.8734
Epoch 9, Attribute attr_3, Validation Accuracy: 0.8294
Epoch 9, Attribute attr_4, Validation Accuracy: 0.7862
Epoch 9, Attribute attr_5, Validation Accuracy: 0.9706
Epoch 9, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 9, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 9, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 9, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 9, Attribute attr_10, Validation Accuracy: 1.0000


                                                                       

Epoch 10, Attribute attr_1, Validation Accuracy: 0.6505
Epoch 10, Attribute attr_2, Validation Accuracy: 0.8716
Epoch 10, Attribute attr_3, Validation Accuracy: 0.8303
Epoch 10, Attribute attr_4, Validation Accuracy: 0.7908
Epoch 10, Attribute attr_5, Validation Accuracy: 0.9725
Epoch 10, Attribute attr_6, Validation Accuracy: 1.0000
Epoch 10, Attribute attr_7, Validation Accuracy: 1.0000
Epoch 10, Attribute attr_8, Validation Accuracy: 1.0000
Epoch 10, Attribute attr_9, Validation Accuracy: 1.0000
Epoch 10, Attribute attr_10, Validation Accuracy: 1.0000


                                                        

Final Test Accuracy for Attribute attr_1: 0.6590
Final Test Accuracy for Attribute attr_2: 0.8698
Final Test Accuracy for Attribute attr_3: 0.8405
Final Test Accuracy for Attribute attr_4: 0.7984
Final Test Accuracy for Attribute attr_5: 0.9578
Final Test Accuracy for Attribute attr_6: 1.0000
Final Test Accuracy for Attribute attr_7: 1.0000
Final Test Accuracy for Attribute attr_8: 1.0000
Final Test Accuracy for Attribute attr_9: 1.0000
Final Test Accuracy for Attribute attr_10: 1.0000




In [21]:
# Save the model
torch.save(model_vit_men_tshirts.state_dict(), 'model_vit_men_tshirts2.pth')
torch.save(model_vit_women_tops.state_dict(), 'model_vit_women_tops2.pth')
torch.save(model_vit_women_tshirts.state_dict(), 'model_vit_women_tshirts2.pth')
torch.save(model_vit_kurtis.state_dict(), 'model_vit_kurtis2.pth')
torch.save(model_vit_sarees.state_dict(), 'model_vit_sarees2.pth')


In [22]:
def predict(image_path, model, category_mappings, processor):
    # Load the image
    image = Image.open(image_path).convert("RGB")
    
    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Run the model to get predictions
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"].to(device))

    predictions = {}

    # Process the outputs for each attribute
    for attr, output in outputs.items():
        probabilities = torch.softmax(output, dim=-1).cpu().numpy()
        predicted_class_index = np.argmax(probabilities, axis=-1)[0]
        
        # Check if index exists in category mapping
        if predicted_class_index in category_mappings[attr]:
            predictions[attr] = category_mappings[attr][predicted_class_index]
        else:
            # Handle unmapped indices
            print(category_mappings, attr, predicted_class_index)
            predictions[attr] = 'unknown'  # or any default value you prefer

    return predictions

In [None]:
from tqdm import tqdm
import contextlib
from IPython.display import HTML
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import pandas as pd
import io

categories_number = {
    'Men Tshirts': 5, 
    'Sarees': 10, 
    'Kurtis': 9, 
    'Women Tshirts': 8, 
    'Women Tops & Tunics': 10
}

model_selection = {
    'Men Tshirts': model_vit_men_tshirts,
    'Sarees': model_vit_sarees,
    'Kurtis': model_vit_kurtis,
    'Women Tshirts': model_vit_women_tshirts,
    'Women Tops & Tunics': model_vit_women_tops
}

mapping_selection = {
    'Men Tshirts': category_mappings_men_tshirts,
    'Sarees': category_mappings_sarees,
    'Kurtis': category_mappings_kurtis,
    'Women Tshirts': category_mappings_women_tshirts,
    'Women Tops & Tunics': category_mappings_women_tops
}

final_results = []

for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Processing Images"):
    category = row['Category']
    
    with contextlib.redirect_stdout(io.StringIO()):
        image_path = '/kaggle/input/visual-taxonomy/test_images/' + "{:06}".format(row['id']) + '.jpg'
        predictions = predict(image_path, model_selection[category], mapping_selection[category], processor)
    
    result = {
        'id': row['id'],
        'Category': category,
        'len': categories_number[category],  
        **predictions  
    }
    
    final_results.append(result)

final_df = pd.DataFrame(final_results)

final_df.to_csv('final_predictions_vit2.csv', index=False)

print("Final CSV generated and saved as 'final_predictions_vit2.csv'.")

Processing Images: 100%|██████████| 30205/30205 [12:47<00:00, 39.35it/s]


Final CSV generated and saved as 'final_predictions_vit2.csv'.


**Predicting for train set**

In [5]:
def predict(image_path, model, category_mappings, processor):
    # Load the image
    image = Image.open(image_path).convert("RGB")
    
    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Run the model to get predictions
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"].to(device))

    predictions = {}

    # Process the outputs for each attribute
    for attr, output in outputs.items():
        probabilities = torch.softmax(output, dim=-1).cpu().numpy()
        predicted_class_index = np.argmax(probabilities, axis=-1)[0]
        
        # Check if index exists in category mapping
        if predicted_class_index in category_mappings[attr]:
            predictions[attr] = category_mappings[attr][predicted_class_index]
        else:
            # Handle unmapped indices
            print(category_mappings, attr, predicted_class_index)
            predictions[attr] = 'unknown'  # or any default value you prefer

    return predictions

In [7]:
category_mappings_men_tshirts = {'attr_1': {0: 'black', 1: 'default', 2: 'multicolor', 3: 'white'}, 'attr_2': {0: 'polo', 1: 'round'}, 'attr_3': {0: 'printed', 1: 'solid'}, 'attr_4': {0: 'default', 1: 'solid', 2: 'typography'}, 'attr_5': {0: 'long sleeves', 1: 'short sleeves'}, 'attr_6': {0: 'dummy_value'}, 'attr_7': {0: 'dummy_value'}, 'attr_8': {0: 'dummy_value'}, 'attr_9': {0: 'dummy_value'}, 'attr_10': {0: 'dummy_value'}}
category_mappings_sarees = {'attr_1': {0: 'default', 1: 'same as border', 2: 'same as saree', 3: 'solid'}, 'attr_2': {0: 'default', 1: 'no border', 2: 'solid', 3: 'temple border', 4: 'woven design', 5: 'zari'}, 'attr_3': {0: 'big border', 1: 'no border', 2: 'small border'}, 'attr_4': {0: 'cream', 1: 'default', 2: 'green', 3: 'multicolor', 4: 'navy blue', 5: 'pink', 6: 'white', 7: 'yellow'}, 'attr_5': {0: 'daily', 1: 'party', 2: 'traditional', 3: 'wedding'}, 'attr_6': {0: 'default', 1: 'jacquard', 2: 'tassels and latkans'}, 'attr_7': {0: 'default', 1: 'same as saree', 2: 'woven design', 3: 'zari woven'}, 'attr_8': {0: 'default', 1: 'printed', 2: 'solid', 3: 'woven design', 4: 'zari woven'}, 'attr_9': {0: 'applique', 1: 'botanical', 2: 'checked', 3: 'default', 4: 'elephant', 5: 'ethnic motif', 6: 'floral', 7: 'peacock', 8: 'solid'}, 'attr_10': {0: 'no', 1: 'yes'}}
category_mappings_kurtis = {'attr_1': {0: 'black', 1: 'blue', 2: 'green', 3: 'grey', 4: 'maroon', 5: 'multicolor', 6: 'navy blue', 7: 'orange', 8: 'pink', 9: 'purple', 10: 'red', 11: 'white', 12: 'yellow'}, 'attr_2': {0: 'a-line', 1: 'straight'}, 'attr_3': {0: 'calf length', 1: 'knee length'}, 'attr_4': {0: 'daily', 1: 'party'}, 'attr_5': {0: 'default', 1: 'net'}, 'attr_6': {0: 'default', 1: 'solid'}, 'attr_7': {0: 'default', 1: 'solid'}, 'attr_8': {0: 'short sleeves', 1: 'sleeveless', 2: 'three-quarter sleeves'}, 'attr_9': {0: 'regular', 1: 'sleeveless'}, 'attr_10': {0: 'dummy_value'}}
category_mappings_women_tshirts = {'attr_1': {0: 'black', 1: 'default', 2: 'maroon', 3: 'multicolor', 4: 'pink', 5: 'white', 6: 'yellow'}, 'attr_2': {0: 'boxy', 1: 'loose', 2: 'regular'}, 'attr_3': {0: 'crop', 1: 'long', 2: 'regular'}, 'attr_4': {0: 'default', 1: 'printed', 2: 'solid'}, 'attr_5': {0: 'default', 1: 'funky print', 2: 'graphic', 3: 'quirky', 4: 'solid', 5: 'typography'}, 'attr_6': {0: 'default', 1: 'long sleeves', 2: 'short sleeves'}, 'attr_7': {0: 'cuffed sleeves', 1: 'regular sleeves'}, 'attr_8': {0: 'applique', 1: 'default'}, 'attr_9': {0: 'dummy_value'}, 'attr_10': {0: 'dummy_value'}}
category_mappings_women_tops = {'attr_1': {0: 'black', 1: 'blue', 2: 'default', 3: 'green', 4: 'maroon', 5: 'multicolor', 6: 'navy blue', 7: 'peach', 8: 'pink', 9: 'red', 10: 'white', 11: 'yellow'}, 'attr_2': {0: 'boxy', 1: 'default', 2: 'fitted', 3: 'regular'}, 'attr_3': {0: 'crop', 1: 'regular'}, 'attr_4': {0: 'default', 1: 'high', 2: 'round neck', 3: 'square neck', 4: 'stylised', 5: 'sweetheart neck', 6: 'v-neck'}, 'attr_5': {0: 'casual', 1: 'party'}, 'attr_6': {0: 'default', 1: 'printed', 2: 'solid'}, 'attr_7': {0: 'default', 1: 'floral', 2: 'graphic', 3: 'quirky', 4: 'solid', 5: 'typography'}, 'attr_8': {0: 'long sleeves', 1: 'short sleeves', 2: 'sleeveless', 3: 'three-quarter sleeves'}, 'attr_9': {0: 'default', 1: 'puff sleeves', 2: 'regular sleeves', 3: 'sleeveless'}, 'attr_10': {0: 'applique', 1: 'default', 2: 'knitted', 3: 'ruffles', 4: 'tie-ups', 5: 'waist tie-ups'}}

In [9]:
class MultiOutputModel(nn.Module):
    def __init__(self, base_model, num_classes_per_attribute):
        super(MultiOutputModel, self).__init__()
        self.base_model = base_model
        self.heads = nn.ModuleDict({
            attr: nn.Linear(base_model.config.hidden_size, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })

    def forward(self, pixel_values):
        outputs = self.base_model(pixel_values=pixel_values)
        cls_hidden_states = outputs.last_hidden_state[:, 0, :]
        logits = {attr: self.heads[attr](cls_hidden_states) for attr in self.heads}
        return logits

In [10]:
# Load the base model
base_model = ViTModel.from_pretrained("google/vit-base-patch16-224")

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

In [11]:
num_classes_per_attribute_men_tshirts = {
    'attr_1': 4,
    'attr_2': 2,
    'attr_3': 2,
    'attr_4': 3,
    'attr_5': 3,
    'attr_6': 1,
    'attr_7': 1,
    'attr_8': 1,
    'attr_9': 1,
    'attr_10': 1
}

# Create an instance of the MultiOutputModel
model_men = MultiOutputModel(base_model, num_classes_per_attribute_men_tshirts)

# Load the saved state dictionary
model_men.load_state_dict(torch.load('/kaggle/working/model_vit_men_tshirts2.pth'))
model_men.to(device)


  model_men.load_state_dict(torch.load('/kaggle/working/model_vit_men_tshirts2.pth'))


MultiOutputModel(
  (base_model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_feat

In [12]:
num_classes_per_attribute_kurtis = {
    'attr_1': 13,
    'attr_2': 2,
    'attr_3': 2,
    'attr_4': 2,
    'attr_5': 2,
    'attr_6': 2,
    'attr_7': 2,
    'attr_8': 3,
    'attr_9': 2,
    'attr_10': 1
}

# Create an instance of the MultiOutputModel
model_kurtis = MultiOutputModel(base_model, num_classes_per_attribute_kurtis)

# Load the saved state dictionary
model_kurtis.load_state_dict(torch.load('/kaggle/working/model_vit_kurtis2.pth'))

# Move the model to the appropriate device (CPU or GPU)
model_kurtis.to(device)

  model_kurtis.load_state_dict(torch.load('/kaggle/working/model_vit_kurtis2.pth'))


MultiOutputModel(
  (base_model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_feat

In [13]:
num_classes_per_attribute_women_tops = {
    'attr_1': 12,
    'attr_2': 4,
    'attr_3': 2,
    'attr_4': 7,
    'attr_5': 2,
    'attr_6': 3,
    'attr_7': 6,
    'attr_8': 4,
    'attr_9': 4,
    'attr_10': 6
}

# Create an instance of the MultiOutputModel
model_women_tops_tunics = MultiOutputModel(base_model, num_classes_per_attribute_women_tops)

# Load the saved state dictionary
model_women_tops_tunics.load_state_dict(torch.load('/kaggle/working/model_vit_women_tops2.pth'))

# Move the model to the appropriate device (CPU or GPU)
model_women_tops_tunics.to(device)

  model_women_tops_tunics.load_state_dict(torch.load('/kaggle/working/model_vit_women_tops2.pth'))


MultiOutputModel(
  (base_model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_feat

In [14]:
num_classes_per_attribute_women_tshirts = {
    'attr_1': 7,
    'attr_2': 3,
    'attr_3': 3,
    'attr_4': 3,
    'attr_5': 6,
    'attr_6': 3,
    'attr_7': 2,
    'attr_8': 2,
    'attr_9': 1,
    'attr_10': 1
}

# Create an instance of the MultiOutputModel
model_women_tshirts = MultiOutputModel(base_model, num_classes_per_attribute_women_tshirts)

# Load the saved state dictionary
model_women_tshirts.load_state_dict(torch.load('/kaggle/working/model_vit_women_tshirts2.pth'))

# Move the model to the appropriate device (CPU or GPU)
model_women_tshirts.to(device)

  model_women_tshirts.load_state_dict(torch.load('/kaggle/working/model_vit_women_tshirts2.pth'))


MultiOutputModel(
  (base_model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_feat

In [15]:
num_classes_per_attribute_sarees = {
    'attr_1': 4,  
    'attr_2': 6,
    'attr_3': 3,
    'attr_4': 8,
    'attr_5': 4,
    'attr_6': 3,
    'attr_7': 4,
    'attr_8': 5,
    'attr_9': 9,
    'attr_10': 2
}

# Create an instance of the MultiOutputModel
model_sarees = MultiOutputModel(base_model, num_classes_per_attribute_sarees)

# Load the saved state dictionary
model_sarees.load_state_dict(torch.load('/kaggle/working/model_vit_sarees2.pth'))

# Move the model to the appropriate device (CPU or GPU)
model_sarees.to(device)

  model_sarees.load_state_dict(torch.load('/kaggle/working/model_vit_sarees2.pth'))


MultiOutputModel(
  (base_model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_feat

In [19]:
from tqdm import tqdm
import contextlib
from IPython.display import HTML
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import pandas as pd
import io

# Mapping of categories to their respective 'lens' values
categories_number = {
    'Men Tshirts': 5, 
    'Sarees': 10, 
    'Kurtis': 9, 
    'Women Tshirts': 8, 
    'Women Tops & Tunics': 10
}

model_selection = {
    'Men Tshirts': model_men,
    'Sarees': model_sarees,
    'Kurtis': model_kurtis,
    'Women Tshirts': model_women_tshirts,
    'Women Tops & Tunics': model_women_tops_tunics
}

mapping_selection = {
    'Men Tshirts': category_mappings_men_tshirts,
    'Sarees': category_mappings_sarees,
    'Kurtis': category_mappings_kurtis,
    'Women Tshirts': category_mappings_women_tshirts,
    'Women Tops & Tunics': category_mappings_women_tops
}

# Initialize an empty list to store the final results
final_results = []

# Iterate over the train DataFrame with tqdm for a single progress bar
for idx, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Processing Images"):
    category = row['Category']
    
    # Suppress output by redirecting stdout temporarily
    with contextlib.redirect_stdout(io.StringIO()):
        image_path = '/kaggle/input/visual-taxonomy/train_images/' + "{:06}".format(row['id']) + '.jpg'
        predictions = predict(image_path, model_selection[category], mapping_selection[category], processor)
    
    # Extract attributes from the predictions
    result = {
        'id': row['id'],
        'Category': category,
        'len': categories_number[category],  # Set 'len' according to the category
        **predictions  # Unpack the dictionary of attributes into the result
    }
    
    final_results.append(result)

# Create the final DataFrame from the results
final_df = pd.DataFrame(final_results)

# Save the final DataFrame to a CSV file
final_df.to_csv('predictions_vit2_train.csv', index=False)

print("Final CSV generated and saved as 'predictions_vit2_train.csv'.")

Processing Images: 100%|██████████| 70213/70213 [32:20<00:00, 36.19it/s]


Final CSV generated and saved as 'predictions_vit2_train.csv'.
