In [None]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import ViTModel, ViTImageProcessor
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import torchvision.transforms as T
import torch.optim as optim
import torchvision.models as models

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
# Every input file sits under the same Kaggle dataset directory.
data_root = '/kaggle/input/visual-taxonomy'

train_csv_path = f'{data_root}/train.csv'
test_csv_path = f'{data_root}/test.csv'
category_attributes_path = f'{data_root}/category_attributes.parquet'
sample_submission_path = f'{data_root}/sample_submission.csv'
images_folder = f'{data_root}/train_images'

In [4]:
# Read the labelled training table, then report its shape and column names.
print("\nLoading train.csv ...")
train_df = pd.read_csv(train_csv_path)
print(
    f"train.csv shape: {train_df.shape}",
    f"train.csv columns: {train_df.columns}",
    sep="\n",
)

# Last expression of the cell, so the notebook renders the first rows.
train_df.head()


Loading train.csv ...
train.csv shape: (70213, 13)
train.csv columns: Index(['id', 'Category', 'len', 'attr_1', 'attr_2', 'attr_3', 'attr_4',
       'attr_5', 'attr_6', 'attr_7', 'attr_8', 'attr_9', 'attr_10'],
      dtype='object')


Unnamed: 0,id,Category,len,attr_1,attr_2,attr_3,attr_4,attr_5,attr_6,attr_7,attr_8,attr_9,attr_10
0,0,Men Tshirts,5,default,round,printed,default,short sleeves,,,,,
1,1,Men Tshirts,5,multicolor,polo,solid,solid,short sleeves,,,,,
2,2,Men Tshirts,5,default,polo,solid,solid,short sleeves,,,,,
3,3,Men Tshirts,5,multicolor,polo,solid,solid,short sleeves,,,,,
4,4,Men Tshirts,5,multicolor,polo,solid,solid,short sleeves,,,,,


In [5]:
# Read the unlabelled test table, then report its shape and column names.
print("\nLoading test.csv ...")
test_df = pd.read_csv(test_csv_path)
print(
    f"test.csv shape: {test_df.shape}",
    f"test.csv columns: {test_df.columns}",
    sep="\n",
)

# Last expression of the cell, so the notebook renders the first rows.
test_df.head()



Loading test.csv ...
test.csv shape: (30205, 2)
test.csv columns: Index(['id', 'Category'], dtype='object')


Unnamed: 0,id,Category
0,0,Men Tshirts
1,1,Men Tshirts
2,2,Men Tshirts
3,3,Men Tshirts
4,4,Men Tshirts


In [39]:
# Split the training table into one independent DataFrame per product category.
# `.copy()` detaches each slice from `train_df`, so later column assignments on a
# category frame cannot trigger SettingWithCopyWarning or touch the parent table.
category_dfs = {
    category: train_df[train_df['Category'] == category].copy()
    for category in train_df['Category'].unique()
}

# Named handles used by the per-category cells further down the notebook.
men_tshirts_df = category_dfs['Men Tshirts']
sarees_df = category_dfs['Sarees']
kurtis_df = category_dfs['Kurtis']
women_tshirts_df = category_dfs['Women Tshirts']
women_tops_df = category_dfs['Women Tops & Tunics']

# Summarise every category frame. `DataFrame.info()` writes its report straight
# to stdout and returns None, so it is called as a statement rather than being
# interpolated into an f-string (which printed the literal text "None").
for title, frame in [
    ("Men Tshirts DataFrame", men_tshirts_df),
    ("Sarees DataFrame", sarees_df),
    ("Kurtis DataFrame", kurtis_df),
    ("Women Tshirts", women_tshirts_df),
    # Previously this section re-printed `women_tshirts_df` by mistake.
    ("Women Tops and Tunics DataFrame", women_tops_df),
]:
    print(f"{title}:")
    frame.info()
    print(frame.head())
    print(len(frame))

<class 'pandas.core.frame.DataFrame'>
Index: 7267 entries, 0 to 7266
Data columns (total 13 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   id        7267 non-null   int64 
 1   Category  7267 non-null   object
 2   len       7267 non-null   int64 
 3   attr_1    6010 non-null   object
 4   attr_2    6144 non-null   object
 5   attr_3    5791 non-null   object
 6   attr_4    5949 non-null   object
 7   attr_5    5977 non-null   object
 8   attr_6    0 non-null      object
 9   attr_7    0 non-null      object
 10  attr_8    0 non-null      object
 11  attr_9    0 non-null      object
 12  attr_10   0 non-null      object
dtypes: int64(2), object(11)
memory usage: 794.8+ KB
Men Tshirts DataFrame:
None
   id     Category  len      attr_1 attr_2   attr_3   attr_4         attr_5  \
0   0  Men Tshirts    5     default  round  printed  default  short sleeves   
1   1  Men Tshirts    5  multicolor   polo    solid    solid  short sleeves   
2   2  Men

In [8]:
# Preprocessing function for images
def preprocess_image(image):
    """Decode a JPEG payload, resize it to 224x224 and apply ResNet preprocessing.

    NOTE(review): this function calls into a module named ``tf`` (presumably
    TensorFlow), which is not imported anywhere in the visible part of this
    notebook, and nothing visible calls this function. Calling it on a fresh
    kernel would raise NameError unless ``tf`` is imported elsewhere.

    Args:
        image: Encoded JPEG data accepted by ``tf.image.decode_jpeg``.

    Returns:
        The result of ``tf.keras.applications.resnet.preprocess_input`` applied
        to the decoded, resized 3-channel image.
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    resized = tf.image.resize(decoded, [224, 224])
    return tf.keras.applications.resnet.preprocess_input(resized)

In [None]:
# Rebuild the working frame from the untouched per-category slice so this cell is
# idempotent: the previous `sarees_df = sarees_df.drop(...)` failed with KeyError
# on a second run because the columns were already gone (and the labels were
# already replaced by integer codes).
sarees_df = category_dfs['Sarees'].drop(columns=['Category', 'len']).copy()

# Attribute columns and, per column, the mapping {integer code -> original label}.
attribute_columns_sarees = [f'attr_{i}' for i in range(1, 11)]
category_mappings_sarees = {}

for col in attribute_columns_sarees:
    values = sarees_df[col]
    # A column that is entirely NaN becomes a single placeholder class; otherwise
    # missing labels are imputed with the most frequent label of that column.
    # NOTE(review): mode imputation turns unknown labels into training targets.
    fill_value = 'dummy_value' if values.isna().all() else values.mode()[0]
    encoded = pd.Categorical(values.fillna(fill_value))
    category_mappings_sarees[col] = dict(enumerate(encoded.categories))
    sarees_df[col] = encoded.codes

# Image files are named after the zero-padded (6 digit) row id.
sarees_df['image_path'] = sarees_df['id'].map(
    lambda image_id: os.path.join(images_folder, f"{image_id:06d}.jpg")
)

# 70% train / 15% validation / 15% test, reproducible through the fixed seed.
train_df_sarees, temp_df = train_test_split(sarees_df, test_size=0.3, random_state=42)
val_df_sarees, test_df_sarees = train_test_split(temp_df, test_size=0.5, random_state=42)

# Number of classes per attribute, derived from the fitted categories so the
# classifier heads can never drift out of sync with the encoded labels.
num_classes_per_attribute = {
    col: len(mapping) for col, mapping in category_mappings_sarees.items()
}

class MultiOutputModel(nn.Module):
    """ResNet-101 feature extractor shared by one linear classifier per attribute.

    Args:
        num_classes_per_attribute: Mapping of attribute name to the number of
            output classes for that attribute's head.
    """

    def __init__(self, num_classes_per_attribute):
        super().__init__()

        # torchvision ResNet-101 with its final classifier swapped for a
        # pass-through layer, so the network returns the features fed to `fc`.
        backbone = models.resnet101(pretrained=True)
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.base_model = backbone

        self.dropout = nn.Dropout(0.3)

        # One independent linear head per attribute, all reading the same features.
        heads = {}
        for attr, num_classes in num_classes_per_attribute.items():
            heads[attr] = nn.Linear(feature_dim, num_classes)
        self.heads = nn.ModuleDict(heads)

    def forward(self, x):
        """Return a dict mapping each attribute name to that head's logits."""
        shared = self.dropout(self.base_model(x))

        logits = {}
        for attr, head in self.heads.items():
            logits[attr] = head(shared)
        return logits


# Build the multi-head ResNet-101 classifier for the saree attributes.
model_resnet101_sarees = MultiOutputModel(num_classes_per_attribute)

# Training-time augmentation: random crop/flip/colour jitter, then tensor
# conversion and per-channel normalisation.
augmentation_steps = [
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
data_augmentation = T.Compose(augmentation_steps)

# Dataset and DataLoader
class CustomDataset(Dataset):
    """Image dataset yielding ``(image, {attribute: label})`` pairs from a DataFrame.

    Args:
        df: Frame with an ``image_path`` column and one integer-coded column per
            attribute.
        transform: Optional callable applied to the loaded RGB PIL image.
        attribute_columns: Names of the label columns. Defaults to every column
            of ``df`` whose name starts with ``attr_`` (in frame order), so the
            class no longer depends on a notebook-level global.
    """

    def __init__(self, df, transform=None, attribute_columns=None):
        self.df = df
        self.transform = transform
        if attribute_columns is None:
            attribute_columns = [col for col in df.columns if str(col).startswith("attr_")]
        self.attribute_columns = list(attribute_columns)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Fetch the row once instead of re-slicing the frame for every field.
        row = self.df.iloc[idx]

        # The context manager closes the file handle as soon as pixels are decoded.
        with Image.open(row["image_path"]) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels = {
            attr: torch.tensor(int(row[attr]), dtype=torch.long)
            for attr in self.attribute_columns
        }
        return image, labels

# Deterministic preprocessing shared by the validation and test splits
# (no random augmentation, only resize + tensor conversion + normalisation).
eval_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Datasets and data loaders
train_dataset = CustomDataset(train_df_sarees, transform=data_augmentation)
val_dataset = CustomDataset(val_df_sarees, transform=eval_transform)
test_dataset = CustomDataset(test_df_sarees, transform=eval_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_resnet101_sarees.to(device)

# Optimizer and cosine learning-rate schedule spanning the whole run
num_epochs = 10
optimizer = torch.optim.AdamW(model_resnet101_sarees.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Mixed precision is only enabled on CUDA. `torch.amp` replaces the deprecated
# `torch.cuda.amp.GradScaler()` / `autocast()` calls that emitted FutureWarnings.
use_amp = device.type == "cuda"
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

# Built once: the loss module is stateless, so there is no reason to
# re-instantiate it for every attribute of every batch.
criterion = nn.CrossEntropyLoss()


def count_correct_per_attribute(model, loader, attributes, device, desc):
    """Run `model` over `loader` in eval mode and tally correct predictions.

    Args:
        model: Module returning a dict of logits keyed by attribute name.
        loader: Iterable of ``(images, labels)`` batches, where ``labels`` is a
            dict of target tensors keyed by attribute name.
        attributes: Attribute names to score.
        device: Device that inputs and targets are moved to.
        desc: Progress-bar caption.

    Returns:
        ``(corrects, samples)``: two dicts keyed by attribute with the number of
        correct predictions and the number of scored samples.
    """
    model.eval()
    corrects = {attr: 0 for attr in attributes}
    samples = {attr: 0 for attr in attributes}

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc, leave=False):
            outputs = model(images.to(device))
            for attr in attributes:
                target = labels[attr].to(device)
                preds = outputs[attr].argmax(dim=1)
                corrects[attr] += (preds == target).sum().item()
                samples[attr] += target.size(0)

    return corrects, samples


# Training and validation loop
for epoch in range(num_epochs):
    # Training (the evaluation helper switches to eval mode, so re-enable here)
    model_resnet101_sarees.train()
    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        with torch.autocast(device_type=device.type, enabled=use_amp):
            outputs = model_resnet101_sarees(images)
            # Total objective is the unweighted sum of the per-attribute losses.
            losses = {attr: criterion(outputs[attr], labels[attr].to(device)) for attr in labels}
            total_loss = sum(losses.values())

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    scheduler.step()

    # Validation accuracy per attribute
    val_corrects, val_samples = count_correct_per_attribute(
        model_resnet101_sarees, val_loader, attribute_columns_sarees, device, "Validating"
    )
    for attr in attribute_columns_sarees:
        acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Epoch {epoch + 1}, Validation Accuracy for {attr}: {acc:.4f}")

# Evaluate the trained model on the held-out test split
test_corrects, test_samples = count_correct_per_attribute(
    model_resnet101_sarees, test_loader, attribute_columns_sarees, device, "Testing"
)

print("\nTest Accuracy per Attribute:")
for attr in attribute_columns_sarees:
    test_acc = test_corrects[attr] / test_samples[attr] if test_samples[attr] > 0 else 0
    print(f"Attribute {attr}: {test_acc:.4f}")


  scaler = GradScaler()
  self.pid = os.fork()
  with autocast():
  self.pid = os.fork()
                                                                      

Epoch 1, Validation Accuracy for attr_1: 0.8626
Epoch 1, Validation Accuracy for attr_2: 0.6799
Epoch 1, Validation Accuracy for attr_3: 0.8285
Epoch 1, Validation Accuracy for attr_4: 0.5469
Epoch 1, Validation Accuracy for attr_5: 0.6784
Epoch 1, Validation Accuracy for attr_6: 0.9560
Epoch 1, Validation Accuracy for attr_7: 0.7122
Epoch 1, Validation Accuracy for attr_8: 0.8532
Epoch 1, Validation Accuracy for attr_9: 0.6363
Epoch 1, Validation Accuracy for attr_10: 0.8343


                                                                      

Epoch 2, Validation Accuracy for attr_1: 0.8601
Epoch 2, Validation Accuracy for attr_2: 0.6621
Epoch 2, Validation Accuracy for attr_3: 0.8474
Epoch 2, Validation Accuracy for attr_4: 0.5458
Epoch 2, Validation Accuracy for attr_5: 0.6966
Epoch 2, Validation Accuracy for attr_6: 0.9560
Epoch 2, Validation Accuracy for attr_7: 0.7242
Epoch 2, Validation Accuracy for attr_8: 0.8488
Epoch 2, Validation Accuracy for attr_9: 0.6461
Epoch 2, Validation Accuracy for attr_10: 0.8343


                                                                      

Epoch 3, Validation Accuracy for attr_1: 0.8601
Epoch 3, Validation Accuracy for attr_2: 0.6791
Epoch 3, Validation Accuracy for attr_3: 0.8310
Epoch 3, Validation Accuracy for attr_4: 0.5505
Epoch 3, Validation Accuracy for attr_5: 0.7053
Epoch 3, Validation Accuracy for attr_6: 0.9557
Epoch 3, Validation Accuracy for attr_7: 0.7158
Epoch 3, Validation Accuracy for attr_8: 0.8539
Epoch 3, Validation Accuracy for attr_9: 0.6530
Epoch 3, Validation Accuracy for attr_10: 0.8361


                                                                      

Epoch 4, Validation Accuracy for attr_1: 0.8619
Epoch 4, Validation Accuracy for attr_2: 0.6879
Epoch 4, Validation Accuracy for attr_3: 0.8496
Epoch 4, Validation Accuracy for attr_4: 0.5483
Epoch 4, Validation Accuracy for attr_5: 0.7151
Epoch 4, Validation Accuracy for attr_6: 0.9564
Epoch 4, Validation Accuracy for attr_7: 0.7329
Epoch 4, Validation Accuracy for attr_8: 0.8666
Epoch 4, Validation Accuracy for attr_9: 0.6512
Epoch 4, Validation Accuracy for attr_10: 0.8350


                                                                      

Epoch 5, Validation Accuracy for attr_1: 0.8630
Epoch 5, Validation Accuracy for attr_2: 0.6908
Epoch 5, Validation Accuracy for attr_3: 0.8419
Epoch 5, Validation Accuracy for attr_4: 0.5643
Epoch 5, Validation Accuracy for attr_5: 0.7144
Epoch 5, Validation Accuracy for attr_6: 0.9608
Epoch 5, Validation Accuracy for attr_7: 0.7355
Epoch 5, Validation Accuracy for attr_8: 0.8681
Epoch 5, Validation Accuracy for attr_9: 0.6570
Epoch 5, Validation Accuracy for attr_10: 0.8350


                                                                      

Epoch 6, Validation Accuracy for attr_1: 0.8637
Epoch 6, Validation Accuracy for attr_2: 0.6868
Epoch 6, Validation Accuracy for attr_3: 0.8488
Epoch 6, Validation Accuracy for attr_4: 0.5723
Epoch 6, Validation Accuracy for attr_5: 0.7104
Epoch 6, Validation Accuracy for attr_6: 0.9517
Epoch 6, Validation Accuracy for attr_7: 0.7398
Epoch 6, Validation Accuracy for attr_8: 0.8659
Epoch 6, Validation Accuracy for attr_9: 0.6446
Epoch 6, Validation Accuracy for attr_10: 0.8358


                                                                      

Epoch 7, Validation Accuracy for attr_1: 0.8659
Epoch 7, Validation Accuracy for attr_2: 0.6962
Epoch 7, Validation Accuracy for attr_3: 0.8528
Epoch 7, Validation Accuracy for attr_4: 0.5705
Epoch 7, Validation Accuracy for attr_5: 0.7144
Epoch 7, Validation Accuracy for attr_6: 0.9578
Epoch 7, Validation Accuracy for attr_7: 0.7366
Epoch 7, Validation Accuracy for attr_8: 0.8695
Epoch 7, Validation Accuracy for attr_9: 0.6584
Epoch 7, Validation Accuracy for attr_10: 0.8361


                                                                      

Epoch 8, Validation Accuracy for attr_1: 0.8594
Epoch 8, Validation Accuracy for attr_2: 0.6951
Epoch 8, Validation Accuracy for attr_3: 0.8474
Epoch 8, Validation Accuracy for attr_4: 0.5705
Epoch 8, Validation Accuracy for attr_5: 0.7162
Epoch 8, Validation Accuracy for attr_6: 0.9520
Epoch 8, Validation Accuracy for attr_7: 0.7391
Epoch 8, Validation Accuracy for attr_8: 0.8681
Epoch 8, Validation Accuracy for attr_9: 0.6573
Epoch 8, Validation Accuracy for attr_10: 0.8350


                                                                      

Epoch 9, Validation Accuracy for attr_1: 0.8597
Epoch 9, Validation Accuracy for attr_2: 0.6948
Epoch 9, Validation Accuracy for attr_3: 0.8503
Epoch 9, Validation Accuracy for attr_4: 0.5701
Epoch 9, Validation Accuracy for attr_5: 0.7148
Epoch 9, Validation Accuracy for attr_6: 0.9578
Epoch 9, Validation Accuracy for attr_7: 0.7395
Epoch 9, Validation Accuracy for attr_8: 0.8656
Epoch 9, Validation Accuracy for attr_9: 0.6570
Epoch 9, Validation Accuracy for attr_10: 0.8350


                                                                       

Epoch 10, Validation Accuracy for attr_1: 0.8605
Epoch 10, Validation Accuracy for attr_2: 0.6962
Epoch 10, Validation Accuracy for attr_3: 0.8499
Epoch 10, Validation Accuracy for attr_4: 0.5730
Epoch 10, Validation Accuracy for attr_5: 0.7151
Epoch 10, Validation Accuracy for attr_6: 0.9578
Epoch 10, Validation Accuracy for attr_7: 0.7380
Epoch 10, Validation Accuracy for attr_8: 0.8703
Epoch 10, Validation Accuracy for attr_9: 0.6595
Epoch 10, Validation Accuracy for attr_10: 0.8354


                                                        


Test Accuracy per Attribute:
Attribute attr_1: 0.8626
Attribute attr_2: 0.6966
Attribute attr_3: 0.8412
Attribute attr_4: 0.5578
Attribute attr_5: 0.7028
Attribute attr_6: 0.9582
Attribute attr_7: 0.7500
Attribute attr_8: 0.8681
Attribute attr_9: 0.6446
Attribute attr_10: 0.8314




In [62]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report


# Collect test-set predictions of the saree model for detailed metrics.
model_resnet101_sarees.eval()

# Per attribute: ground-truth and predicted class codes as plain Python ints.
true_labels_all = {attr: [] for attr in attribute_columns_sarees}
pred_labels_all = {attr: [] for attr in attribute_columns_sarees}

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing", leave=False):
        outputs = model_resnet101_sarees(images.to(device))

        for attr, target in labels.items():
            preds = outputs[attr].argmax(dim=1)
            # Targets never need to visit the GPU just to be copied back.
            true_labels_all[attr].extend(target.tolist())
            pred_labels_all[attr].extend(preds.cpu().tolist())

# Accuracy, precision, recall, F1 and a classification report per attribute
for attr in attribute_columns_sarees:
    true_labels = true_labels_all[attr]
    pred_labels = pred_labels_all[attr]

    # Pass the full code list alongside the names: without `labels=`, sklearn
    # raises when a class is absent from the test split because the number of
    # `target_names` no longer matches the classes it finds.
    class_ids = list(category_mappings_sarees[attr].keys())
    class_names = [str(name) for name in category_mappings_sarees[attr].values()]

    acc = accuracy_score(true_labels, pred_labels)

    # zero_division=0 scores a never-predicted class as 0 (matching the report)
    # instead of crediting it with a perfect 1.0.
    precision = precision_score(true_labels, pred_labels, average='weighted', zero_division=0)
    recall = recall_score(true_labels, pred_labels, average='weighted', zero_division=0)
    f1 = f1_score(true_labels, pred_labels, average='weighted', zero_division=0)

    print(f"\nClassification Report for {attr}:")
    print(classification_report(
        true_labels,
        pred_labels,
        labels=class_ids,
        target_names=class_names,
        zero_division=0,
    ))

    print(f"Attribute {attr} - Accuracy: {acc:.4f}")
    print(f"Attribute {attr} - Precision: {precision:.4f}")
    print(f"Attribute {attr} - Recall: {recall:.4f}")
    print(f"Attribute {attr} - F1 Score: {f1:.4f}")

# Pool all attributes for overall metrics. Codes are prefixed with the attribute
# name because code 0 of attr_1 and code 0 of attr_2 are unrelated classes;
# pooling the raw integers merged them into one class and distorted the pooled
# precision/recall/F1 (pooled accuracy is unaffected by the prefix).
all_true_labels = []
all_pred_labels = []
for attr in attribute_columns_sarees:
    all_true_labels.extend(f"{attr}:{label}" for label in true_labels_all[attr])
    all_pred_labels.extend(f"{attr}:{label}" for label in pred_labels_all[attr])

overall_acc = accuracy_score(all_true_labels, all_pred_labels)
overall_precision = precision_score(all_true_labels, all_pred_labels, average='weighted', zero_division=0)
overall_recall = recall_score(all_true_labels, all_pred_labels, average='weighted', zero_division=0)
overall_f1 = f1_score(all_true_labels, all_pred_labels, average='weighted', zero_division=0)

print("\nOverall Metrics:")
print(f"Overall Accuracy: {overall_acc:.4f}")
print(f"Overall Precision: {overall_precision:.4f}")
print(f"Overall Recall: {overall_recall:.4f}")
print(f"Overall F1 Score: {overall_f1:.4f}")


  self.pid = os.fork()
  self.pid = os.fork()
                                                        


Classification Report for attr_1:
                precision    recall  f1-score   support

       default       0.59      0.38      0.46       120
same as border       0.47      0.13      0.20       254
 same as saree       0.88      0.97      0.93      2355
         solid       0.00      0.00      0.00        23

      accuracy                           0.86      2752
     macro avg       0.49      0.37      0.40      2752
  weighted avg       0.82      0.86      0.83      2752

Attribute attr_1 - Accuracy: 0.8626
Attribute attr_1 - Precision: 0.8237
Attribute attr_1 - Recall: 0.8626
Attribute attr_1 - F1 Score: 0.8313

Classification Report for attr_2:
               precision    recall  f1-score   support

      default       0.81      0.47      0.59       120
    no border       0.50      0.53      0.52        43
        solid       0.59      0.73      0.65        62
temple border       0.58      0.80      0.67       200
 woven design       0.63      0.60      0.62       939
     

In [None]:
# Data Preprocessing
# Rebuild the working frame from the untouched per-category slice so this cell is
# idempotent: the previous `kurtis_df = kurtis_df.drop(...)` failed with KeyError
# on a second run because the columns were already gone.
kurtis_df = category_dfs['Kurtis'].drop(columns=['Category', 'len']).copy()

# Attribute columns and, per column, the mapping {integer code -> original label}.
attribute_columns_kurtis = [f'attr_{i}' for i in range(1, 11)]
category_mappings_kurtis = {}

for col in attribute_columns_kurtis:
    values = kurtis_df[col]
    # A column that is entirely NaN becomes a single placeholder class; otherwise
    # missing labels are imputed with the most frequent label of that column.
    # NOTE(review): mode imputation turns unknown labels into training targets.
    fill_value = 'dummy_value' if values.isna().all() else values.mode()[0]
    encoded = pd.Categorical(values.fillna(fill_value))
    category_mappings_kurtis[col] = dict(enumerate(encoded.categories))
    kurtis_df[col] = encoded.codes

# Image files are named after the zero-padded (6 digit) row id.
kurtis_df['image_path'] = kurtis_df['id'].map(
    lambda image_id: os.path.join(images_folder, f"{image_id:06d}.jpg")
)

# Split dataset: 70% train / 15% validation / 15% test with a fixed seed.
train_df_kurtis, temp_df_kurtis = train_test_split(kurtis_df, test_size=0.3, random_state=42)
val_df_kurtis, test_df_kurtis = train_test_split(temp_df_kurtis, test_size=0.5, random_state=42)

# Number of classes per attribute, derived from the fitted categories so the
# classifier heads can never drift out of sync with the encoded labels.
num_classes_per_attribute_kurtis = {
    col: len(mapping) for col, mapping in category_mappings_kurtis.items()
}

# Custom Dataset
class CustomDataset(Dataset):
    """Image dataset yielding ``(image, {attribute: label})`` pairs from a DataFrame.

    Args:
        df: Frame with an ``image_path`` column and one integer-coded column per
            attribute.
        transform: Optional callable applied to the loaded RGB PIL image.
        attribute_columns: Names of the label columns. Defaults to every column
            of ``df`` whose name starts with ``attr_`` (in frame order), so the
            class no longer depends on a notebook-level global.
    """

    def __init__(self, df, transform=None, attribute_columns=None):
        self.df = df
        self.transform = transform
        if attribute_columns is None:
            attribute_columns = [col for col in df.columns if str(col).startswith("attr_")]
        self.attribute_columns = list(attribute_columns)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Fetch the row once instead of re-slicing the frame for every field.
        row = self.df.iloc[idx]

        # The context manager closes the file handle as soon as pixels are decoded.
        with Image.open(row["image_path"]) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels = {
            attr: torch.tensor(int(row[attr]), dtype=torch.long)
            for attr in self.attribute_columns
        }
        return image, labels

# Transforms, datasets and loaders for the Kurtis splits.
# Per-channel mean/std handed to T.Normalize in both pipelines.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop/flip/colour jitter before normalisation.
train_steps = [
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=channel_mean, std=channel_std),
]
data_augmentation = T.Compose(train_steps)

# Evaluation pipeline: fixed resize only, no randomness.
eval_steps = [
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=channel_mean, std=channel_std),
]
val_test_transforms = T.Compose(eval_steps)

train_dataset_kurtis = CustomDataset(train_df_kurtis, transform=data_augmentation)
val_dataset_kurtis = CustomDataset(val_df_kurtis, transform=val_test_transforms)
test_dataset_kurtis = CustomDataset(test_df_kurtis, transform=val_test_transforms)

# Only the training loader shuffles; batch size and worker count are shared.
loader_options = {"batch_size": 32, "num_workers": 4}
train_loader_kurtis = DataLoader(train_dataset_kurtis, shuffle=True, **loader_options)
val_loader_kurtis = DataLoader(val_dataset_kurtis, shuffle=False, **loader_options)
test_loader_kurtis = DataLoader(test_dataset_kurtis, shuffle=False, **loader_options)

# Model Definition
class MultiOutputResNet(nn.Module):
    """ResNet-101 trunk with global average pooling and one linear head per attribute.

    Args:
        num_classes_per_attribute: Mapping of attribute name to the number of
            output classes for that attribute's head.
    """

    def __init__(self, num_classes_per_attribute):
        super().__init__()
        # `resnet101` was referenced as a bare name that the notebook never
        # imports; resolve it through the already-imported `torchvision.models`.
        base_model = models.resnet101(pretrained=True)
        # Read the feature width from the backbone instead of hard-coding 2048.
        feature_dim = base_model.fc.in_features
        # Keep everything except the backbone's last two children (its own
        # pooling and classifier); pooling is re-applied explicitly below.
        self.base = nn.Sequential(*list(base_model.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.3)
        self.heads = nn.ModuleDict({
            attr: nn.Linear(feature_dim, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })

    def forward(self, x):
        """Return a dict mapping each attribute name to that head's logits."""
        x = self.base(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        return {attr: head(x) for attr, head in self.heads.items()}

model_kurtis = MultiOutputResNet(num_classes_per_attribute_kurtis)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_kurtis.to(device)

# Optimizer, Scheduler, and Loss
num_epochs = 10
optimizer = optim.AdamW(model_kurtis.parameters(), lr=5e-5, weight_decay=0.01)
# `CosineAnnealingLR` was used as a bare name that the notebook never imports;
# reach it through the imported `torch.optim` module. The schedule length is
# tied to the epoch count so the two cannot drift apart.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
loss_fn = nn.CrossEntropyLoss()

# Training Loop
for epoch in range(num_epochs):
    model_kurtis.train()
    for images, labels in tqdm(train_loader_kurtis, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        outputs = model_kurtis(images)
        # Total objective is the unweighted sum of the per-attribute losses.
        total_loss = 0
        for attr in labels:
            target = labels[attr].to(device)
            total_loss = total_loss + loss_fn(outputs[attr], target)

        total_loss.backward()
        optimizer.step()

    # One scheduler step per epoch.
    scheduler.step()

    # Validation: count correct predictions per attribute.
    model_kurtis.eval()
    val_corrects = {attr: 0 for attr in attribute_columns_kurtis}
    val_samples = {attr: 0 for attr in attribute_columns_kurtis}

    with torch.no_grad():
        for images, labels in val_loader_kurtis:
            images = images.to(device)
            outputs = model_kurtis(images)

            for attr in labels:
                target = labels[attr].to(device)
                preds = outputs[attr].argmax(dim=1)
                val_corrects[attr] += (preds == target).sum().item()
                val_samples[attr] += target.size(0)

    for attr in attribute_columns_kurtis:
        acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Validation Accuracy for {attr}: {acc:.4f}")

# Testing: same tally on the held-out test split after training finishes.
model_kurtis.eval()
test_corrects = {attr: 0 for attr in attribute_columns_kurtis}
test_samples = {attr: 0 for attr in attribute_columns_kurtis}

with torch.no_grad():
    for images, labels in test_loader_kurtis:
        images = images.to(device)
        outputs = model_kurtis(images)

        for attr in labels:
            target = labels[attr].to(device)
            preds = outputs[attr].argmax(dim=1)
            test_corrects[attr] += (preds == target).sum().item()
            test_samples[attr] += target.size(0)

for attr in attribute_columns_kurtis:
    acc = test_corrects[attr] / test_samples[attr] if test_samples[attr] > 0 else 0
    print(f"Test Accuracy for {attr}: {acc:.4f}")


  self.pid = os.fork()
  self.pid = os.fork()
                                                                      

Validation Accuracy for attr_1: 0.6862
Validation Accuracy for attr_2: 0.8788
Validation Accuracy for attr_3: 0.8192
Validation Accuracy for attr_4: 0.9394
Validation Accuracy for attr_5: 0.9022
Validation Accuracy for attr_6: 0.7605
Validation Accuracy for attr_7: 0.7370
Validation Accuracy for attr_8: 0.9159
Validation Accuracy for attr_9: 0.9736
Validation Accuracy for attr_10: 1.0000


                                                                      

Validation Accuracy for attr_1: 0.7322
Validation Accuracy for attr_2: 0.8925
Validation Accuracy for attr_3: 0.8143
Validation Accuracy for attr_4: 0.9394
Validation Accuracy for attr_5: 0.9081
Validation Accuracy for attr_6: 0.7507
Validation Accuracy for attr_7: 0.7214
Validation Accuracy for attr_8: 0.9413
Validation Accuracy for attr_9: 0.9932
Validation Accuracy for attr_10: 1.0000


                                                                      

Validation Accuracy for attr_1: 0.7713
Validation Accuracy for attr_2: 0.9013
Validation Accuracy for attr_3: 0.8270
Validation Accuracy for attr_4: 0.9443
Validation Accuracy for attr_5: 0.9071
Validation Accuracy for attr_6: 0.7595
Validation Accuracy for attr_7: 0.7419
Validation Accuracy for attr_8: 0.9541
Validation Accuracy for attr_9: 0.9971
Validation Accuracy for attr_10: 1.0000


                                                                      

Validation Accuracy for attr_1: 0.7830
Validation Accuracy for attr_2: 0.8895
Validation Accuracy for attr_3: 0.8123
Validation Accuracy for attr_4: 0.9423
Validation Accuracy for attr_5: 0.9120
Validation Accuracy for attr_6: 0.7644
Validation Accuracy for attr_7: 0.7674
Validation Accuracy for attr_8: 0.9492
Validation Accuracy for attr_9: 0.9932
Validation Accuracy for attr_10: 1.0000


                                                                      

Validation Accuracy for attr_1: 0.8182
Validation Accuracy for attr_2: 0.9042
Validation Accuracy for attr_3: 0.8231
Validation Accuracy for attr_4: 0.9384
Validation Accuracy for attr_5: 0.9169
Validation Accuracy for attr_6: 0.7722
Validation Accuracy for attr_7: 0.7527
Validation Accuracy for attr_8: 0.9580
Validation Accuracy for attr_9: 0.9941
Validation Accuracy for attr_10: 1.0000


                                                                      

Validation Accuracy for attr_1: 0.8133
Validation Accuracy for attr_2: 0.9062
Validation Accuracy for attr_3: 0.8182
Validation Accuracy for attr_4: 0.9443
Validation Accuracy for attr_5: 0.9238
Validation Accuracy for attr_6: 0.7693
Validation Accuracy for attr_7: 0.7625
Validation Accuracy for attr_8: 0.9609
Validation Accuracy for attr_9: 0.9961
Validation Accuracy for attr_10: 1.0000


                                                                      

Validation Accuracy for attr_1: 0.8231
Validation Accuracy for attr_2: 0.8993
Validation Accuracy for attr_3: 0.8299
Validation Accuracy for attr_4: 0.9394
Validation Accuracy for attr_5: 0.9198
Validation Accuracy for attr_6: 0.7732
Validation Accuracy for attr_7: 0.7664
Validation Accuracy for attr_8: 0.9589
Validation Accuracy for attr_9: 0.9951
Validation Accuracy for attr_10: 1.0000


                                                                      

Validation Accuracy for attr_1: 0.8240
Validation Accuracy for attr_2: 0.9013
Validation Accuracy for attr_3: 0.8260
Validation Accuracy for attr_4: 0.9423
Validation Accuracy for attr_5: 0.9218
Validation Accuracy for attr_6: 0.7732
Validation Accuracy for attr_7: 0.7742
Validation Accuracy for attr_8: 0.9570
Validation Accuracy for attr_9: 0.9961
Validation Accuracy for attr_10: 1.0000


                                                                      

Validation Accuracy for attr_1: 0.8299
Validation Accuracy for attr_2: 0.9081
Validation Accuracy for attr_3: 0.8358
Validation Accuracy for attr_4: 0.9404
Validation Accuracy for attr_5: 0.9179
Validation Accuracy for attr_6: 0.7810
Validation Accuracy for attr_7: 0.7693
Validation Accuracy for attr_8: 0.9589
Validation Accuracy for attr_9: 0.9971
Validation Accuracy for attr_10: 1.0000


                                                                       

Validation Accuracy for attr_1: 0.8309
Validation Accuracy for attr_2: 0.9101
Validation Accuracy for attr_3: 0.8319
Validation Accuracy for attr_4: 0.9433
Validation Accuracy for attr_5: 0.9238
Validation Accuracy for attr_6: 0.7791
Validation Accuracy for attr_7: 0.7654
Validation Accuracy for attr_8: 0.9580
Validation Accuracy for attr_9: 0.9961
Validation Accuracy for attr_10: 1.0000
Test Accuracy for attr_1: 0.7891
Test Accuracy for attr_2: 0.8926
Test Accuracy for attr_3: 0.8154
Test Accuracy for attr_4: 0.9355
Test Accuracy for attr_5: 0.9355
Test Accuracy for attr_6: 0.7734
Test Accuracy for attr_7: 0.7725
Test Accuracy for attr_8: 0.9629
Test Accuracy for attr_9: 0.9932
Test Accuracy for attr_10: 1.0000


In [None]:
# Prepare and process data
# Rebuild the working frame from the untouched per-category slice so this cell is
# idempotent: the previous `men_tshirts_df = men_tshirts_df.drop(...)` failed
# with KeyError on a second run because the columns were already gone.
men_tshirts_df = category_dfs['Men Tshirts'].drop(columns=['Category', 'len']).copy()

# Attribute columns and, per column, the mapping {integer code -> original label}.
attribute_columns_men_tshirts = [f'attr_{i}' for i in range(1, 11)]
category_mappings_men_tshirts = {}

for col in attribute_columns_men_tshirts:
    values = men_tshirts_df[col]
    # A column that is entirely NaN becomes a single placeholder class; otherwise
    # missing labels are imputed with the most frequent label of that column.
    # NOTE(review): mode imputation turns unknown labels into training targets.
    fill_value = 'dummy_value' if values.isna().all() else values.mode()[0]
    encoded = pd.Categorical(values.fillna(fill_value))
    category_mappings_men_tshirts[col] = dict(enumerate(encoded.categories))
    men_tshirts_df[col] = encoded.codes

# Image files are named after the zero-padded (6 digit) row id.
men_tshirts_df['image_path'] = men_tshirts_df['id'].map(
    lambda image_id: os.path.join(images_folder, f"{image_id:06d}.jpg")
)

# Split dataset into train (70%), validation (15%) and test (15%) sets.
train_df_men_tshirts, temp_df = train_test_split(men_tshirts_df, test_size=0.3, random_state=42)
val_df_men_tshirts, test_df_men_tshirts = train_test_split(temp_df, test_size=0.5, random_state=42)

# Number of classes per attribute, derived from the fitted categories so the
# classifier heads can never drift out of sync with the encoded labels.
num_classes_per_attribute = {
    col: len(mapping) for col, mapping in category_mappings_men_tshirts.items()
}

class MultiOutputModel(nn.Module):
    """ResNet101 backbone with one linear classification head per attribute.

    Args:
        num_classes_per_attribute: mapping of attribute name -> number of output
            classes; one ``nn.Linear`` head is created per entry.
        dropout: dropout probability applied to the backbone features before the
            heads (defaults to the previously hard-coded 0.3).
    """

    def __init__(self, num_classes_per_attribute, dropout=0.3):
        super().__init__()

        # Load ResNet101 as the base model. `pretrained=True` is deprecated in
        # torchvision; the weights enum below selects the same ImageNet checkpoint.
        self.base_model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)

        # Save the original fc's in_features before replacing it
        in_features = self.base_model.fc.in_features

        # Replace the fully connected layer with an identity layer so the backbone
        # returns the pooled feature vector instead of ImageNet logits
        self.base_model.fc = nn.Identity()

        # Add dropout
        self.dropout = nn.Dropout(dropout)

        # Create classification heads for each attribute
        self.heads = nn.ModuleDict({
            attr: nn.Linear(in_features, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })

    def forward(self, x):
        """Return a dict mapping each attribute name to that head's logits for batch ``x``."""
        # Extract features using the base model
        features = self.base_model(x)

        # Apply dropout
        features = self.dropout(features)

        # Compute logits for each attribute
        return {attr: head(features) for attr, head in self.heads.items()}


# Instantiate the multi-head ResNet101 classifier for this category
model_resnet101 = MultiOutputModel(num_classes_per_attribute=num_classes_per_attribute)

# Training-time augmentation: random resized crop to 224, horizontal flip and colour
# jitter, followed by tensor conversion and per-channel normalisation.
data_augmentation = T.Compose(
    transforms=[
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Dataset and DataLoader
class CustomDataset(Dataset):
    """Dataset yielding ``(image, labels)`` pairs from a DataFrame of image paths.

    Args:
        df: DataFrame with an ``image_path`` column and one integer-coded column
            per attribute.
        transform: optional callable applied to the loaded RGB PIL image.
        attribute_columns: names of the label columns. When omitted, every column
            of ``df`` whose name starts with ``attr_`` is used, in column order,
            so the dataset no longer depends on a module-level list.
    """

    def __init__(self, df, transform=None, attribute_columns=None):
        self.df = df
        self.transform = transform
        if attribute_columns is None:
            attribute_columns = [c for c in df.columns if str(c).startswith('attr_')]
        self.attribute_columns = list(attribute_columns)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at positional index ``idx`` and return it with a dict of label tensors."""
        # Fetch the row once instead of re-materialising it for every field
        row = self.df.iloc[idx]

        # convert() returns a fresh image, so the file can be closed right away
        with Image.open(row["image_path"]) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels = {attr: torch.tensor(row[attr], dtype=torch.long) for attr in self.attribute_columns}
        return image, labels

# Initialize datasets and data loaders
# Validation and test share one deterministic pipeline (plain resize, no augmentation)
eval_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = CustomDataset(train_df_men_tshirts, transform=data_augmentation)
val_dataset = CustomDataset(val_df_men_tshirts, transform=eval_transform)
test_dataset = CustomDataset(test_df_men_tshirts, transform=eval_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Device Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_resnet101.to(device)

# Number of epochs is defined before the scheduler so the cosine schedule spans
# exactly the training run instead of relying on a separately hard-coded 10.
num_epochs = 10

# Optimizer and Learning Rate Scheduler
optimizer = torch.optim.AdamW(model_resnet101.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler();
# loss scaling is only enabled when actually running on CUDA.
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))

# Training and Validation Loop

# Single loss module reused for every head and batch (CrossEntropyLoss holds no state
# here, so rebuilding it per attribute per step was unnecessary work).
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    # Training
    model_resnet101.train()
    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        # Mixed-precision forward pass; torch.autocast replaces the deprecated
        # torch.cuda.amp.autocast() and is a no-op when not running on CUDA.
        with torch.autocast(device_type="cuda", enabled=(device.type == "cuda")):
            outputs = model_resnet101(images)
            losses = {attr: criterion(outputs[attr], labels[attr].to(device)) for attr in labels}
            # Unweighted sum of the per-attribute losses
            total_loss = sum(losses.values())

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # Learning rate is stepped once per epoch
    scheduler.step()

    # Validation
    model_resnet101.eval()
    val_corrects = {attr: 0 for attr in attribute_columns_men_tshirts}
    val_samples = {attr: 0 for attr in attribute_columns_men_tshirts}

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            outputs = model_resnet101(images)
            for attr in labels:
                target = labels[attr].to(device)
                _, preds = torch.max(outputs[attr], 1)
                val_corrects[attr] += torch.sum(preds == target).item()
                val_samples[attr] += target.size(0)

    # Validation Accuracy per Attribute
    for attr in attribute_columns_men_tshirts:
        acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Epoch {epoch + 1}, Validation Accuracy for {attr}: {acc:.4f}")

# Held-out test evaluation: count correct predictions and samples per attribute
model_resnet101.eval()
test_corrects = dict.fromkeys(num_classes_per_attribute, 0)
test_samples = dict.fromkeys(num_classes_per_attribute, 0)

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing", leave=False):
        images = images.to(device)
        outputs = model_resnet101(images)

        for attr in labels:
            target = labels[attr].to(device)
            # Predicted class = index of the largest logit
            preds = outputs[attr].argmax(dim=1)
            test_corrects[attr] += (preds == target).sum().item()
            test_samples[attr] += target.size(0)

# Report accuracy per attribute (0 when an attribute saw no samples)
print("\nTest Accuracy per Attribute:")
for attr in num_classes_per_attribute:
    if test_samples[attr] > 0:
        test_acc = test_corrects[attr] / test_samples[attr]
    else:
        test_acc = 0
    print(f"Attribute {attr}: {test_acc:.4f}")


  scaler = GradScaler()
  self.pid = os.fork()
  with autocast():
  self.pid = os.fork()
                                                                      

Epoch 1, Validation Accuracy for attr_1: 0.6257
Epoch 1, Validation Accuracy for attr_2: 0.8761
Epoch 1, Validation Accuracy for attr_3: 0.8321
Epoch 1, Validation Accuracy for attr_4: 0.7670
Epoch 1, Validation Accuracy for attr_5: 0.9578
Epoch 1, Validation Accuracy for attr_6: 1.0000
Epoch 1, Validation Accuracy for attr_7: 1.0000
Epoch 1, Validation Accuracy for attr_8: 1.0000
Epoch 1, Validation Accuracy for attr_9: 1.0000
Epoch 1, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 2, Validation Accuracy for attr_1: 0.6367
Epoch 2, Validation Accuracy for attr_2: 0.8899
Epoch 2, Validation Accuracy for attr_3: 0.8486
Epoch 2, Validation Accuracy for attr_4: 0.7963
Epoch 2, Validation Accuracy for attr_5: 0.9716
Epoch 2, Validation Accuracy for attr_6: 1.0000
Epoch 2, Validation Accuracy for attr_7: 1.0000
Epoch 2, Validation Accuracy for attr_8: 1.0000
Epoch 2, Validation Accuracy for attr_9: 1.0000
Epoch 2, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 3, Validation Accuracy for attr_1: 0.6514
Epoch 3, Validation Accuracy for attr_2: 0.8844
Epoch 3, Validation Accuracy for attr_3: 0.8450
Epoch 3, Validation Accuracy for attr_4: 0.7853
Epoch 3, Validation Accuracy for attr_5: 0.9725
Epoch 3, Validation Accuracy for attr_6: 1.0000
Epoch 3, Validation Accuracy for attr_7: 1.0000
Epoch 3, Validation Accuracy for attr_8: 1.0000
Epoch 3, Validation Accuracy for attr_9: 1.0000
Epoch 3, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 4, Validation Accuracy for attr_1: 0.6321
Epoch 4, Validation Accuracy for attr_2: 0.8734
Epoch 4, Validation Accuracy for attr_3: 0.8532
Epoch 4, Validation Accuracy for attr_4: 0.7679
Epoch 4, Validation Accuracy for attr_5: 0.9734
Epoch 4, Validation Accuracy for attr_6: 1.0000
Epoch 4, Validation Accuracy for attr_7: 1.0000
Epoch 4, Validation Accuracy for attr_8: 1.0000
Epoch 4, Validation Accuracy for attr_9: 1.0000
Epoch 4, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 5, Validation Accuracy for attr_1: 0.6541
Epoch 5, Validation Accuracy for attr_2: 0.8844
Epoch 5, Validation Accuracy for attr_3: 0.8413
Epoch 5, Validation Accuracy for attr_4: 0.7890
Epoch 5, Validation Accuracy for attr_5: 0.9697
Epoch 5, Validation Accuracy for attr_6: 1.0000
Epoch 5, Validation Accuracy for attr_7: 1.0000
Epoch 5, Validation Accuracy for attr_8: 1.0000
Epoch 5, Validation Accuracy for attr_9: 1.0000
Epoch 5, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 6, Validation Accuracy for attr_1: 0.6440
Epoch 6, Validation Accuracy for attr_2: 0.8789
Epoch 6, Validation Accuracy for attr_3: 0.8404
Epoch 6, Validation Accuracy for attr_4: 0.7615
Epoch 6, Validation Accuracy for attr_5: 0.9679
Epoch 6, Validation Accuracy for attr_6: 1.0000
Epoch 6, Validation Accuracy for attr_7: 1.0000
Epoch 6, Validation Accuracy for attr_8: 1.0000
Epoch 6, Validation Accuracy for attr_9: 1.0000
Epoch 6, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 7, Validation Accuracy for attr_1: 0.6550
Epoch 7, Validation Accuracy for attr_2: 0.8780
Epoch 7, Validation Accuracy for attr_3: 0.8422
Epoch 7, Validation Accuracy for attr_4: 0.7826
Epoch 7, Validation Accuracy for attr_5: 0.9661
Epoch 7, Validation Accuracy for attr_6: 1.0000
Epoch 7, Validation Accuracy for attr_7: 1.0000
Epoch 7, Validation Accuracy for attr_8: 1.0000
Epoch 7, Validation Accuracy for attr_9: 1.0000
Epoch 7, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 8, Validation Accuracy for attr_1: 0.6633
Epoch 8, Validation Accuracy for attr_2: 0.8661
Epoch 8, Validation Accuracy for attr_3: 0.8440
Epoch 8, Validation Accuracy for attr_4: 0.7688
Epoch 8, Validation Accuracy for attr_5: 0.9679
Epoch 8, Validation Accuracy for attr_6: 1.0000
Epoch 8, Validation Accuracy for attr_7: 1.0000
Epoch 8, Validation Accuracy for attr_8: 1.0000
Epoch 8, Validation Accuracy for attr_9: 1.0000
Epoch 8, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 9, Validation Accuracy for attr_1: 0.6606
Epoch 9, Validation Accuracy for attr_2: 0.8697
Epoch 9, Validation Accuracy for attr_3: 0.8385
Epoch 9, Validation Accuracy for attr_4: 0.7771
Epoch 9, Validation Accuracy for attr_5: 0.9688
Epoch 9, Validation Accuracy for attr_6: 1.0000
Epoch 9, Validation Accuracy for attr_7: 1.0000
Epoch 9, Validation Accuracy for attr_8: 1.0000
Epoch 9, Validation Accuracy for attr_9: 1.0000
Epoch 9, Validation Accuracy for attr_10: 1.0000


                                                                       

Epoch 10, Validation Accuracy for attr_1: 0.6624
Epoch 10, Validation Accuracy for attr_2: 0.8661
Epoch 10, Validation Accuracy for attr_3: 0.8450
Epoch 10, Validation Accuracy for attr_4: 0.7771
Epoch 10, Validation Accuracy for attr_5: 0.9688
Epoch 10, Validation Accuracy for attr_6: 1.0000
Epoch 10, Validation Accuracy for attr_7: 1.0000
Epoch 10, Validation Accuracy for attr_8: 1.0000
Epoch 10, Validation Accuracy for attr_9: 1.0000
Epoch 10, Validation Accuracy for attr_10: 1.0000


                                                        


Test Accuracy per Attribute:
Attribute attr_1: 0.6664
Attribute attr_2: 0.8698
Attribute attr_3: 0.8368
Attribute attr_4: 0.7938
Attribute attr_5: 0.9578
Attribute attr_6: 1.0000
Attribute attr_7: 1.0000
Attribute attr_8: 1.0000
Attribute attr_9: 1.0000
Attribute attr_10: 1.0000




In [None]:
# Prepare and process data
# NOTE: this cell rebinds women_tops_df (drops columns, then replaces the attribute
# strings with integer codes), so it is meant to run once per fresh kernel: a second
# run fails at drop() below because 'Category' and 'len' are already gone.
women_tops_df = women_tops_df.drop(columns=['Category', 'len'])

# Missing values: a column that is entirely NaN becomes the constant 'dummy_value';
# otherwise object/category columns are filled with their most frequent value.
# Columns of any other dtype pass through unchanged.
women_tops_df = women_tops_df.apply(
    lambda col: col.fillna('dummy_value') if col.isna().all()
    else col.fillna(col.mode()[0]) if col.dtype == 'object' or col.dtype.name == 'category'
    else col
)

# Define attribute columns and category mappings
attribute_columns_women_tops = [f'attr_{i}' for i in range(1, 11)]
category_mappings_women_tops = {}

for col in attribute_columns_women_tops:
    women_tops_df[col] = pd.Categorical(women_tops_df[col])
    # integer code -> original label lookup for this attribute
    category_mappings_women_tops[col] = dict(enumerate(women_tops_df[col].cat.categories))
    women_tops_df[col] = women_tops_df[col].cat.codes

# Image file names are the zero-padded 6-digit id with a .jpg extension
women_tops_df['image_path'] = '/kaggle/input/visual-taxonomy/train_images/' + women_tops_df['id'].apply(lambda x: f"{x:06d}.jpg")

# Split dataset into train, validation, and test sets (70% / 15% / 15%)
train_df_women_tops, temp_df = train_test_split(women_tops_df, test_size=0.3, random_state=42)
val_df_women_tops, test_df_women_tops = train_test_split(temp_df, test_size=0.5, random_state=42)

# Define number of classes for each attribute
num_classes_per_attribute_women_tops = {
    'attr_1': 12,
    'attr_2': 4,
    'attr_3': 2,
    'attr_4': 7,
    'attr_5': 2,
    'attr_6': 3,
    'attr_7': 6,
    'attr_8': 4,
    'attr_9': 4,
    'attr_10': 6
}

# Fail fast if the hand-written head sizes cannot hold every encoded label; otherwise
# the mismatch only shows up later as an out-of-range target error inside the loss.
for col in attribute_columns_women_tops:
    n_labels = len(category_mappings_women_tops[col])
    assert n_labels <= num_classes_per_attribute_women_tops[col], (
        f"{col}: data contains {n_labels} distinct labels but the head is sized for "
        f"{num_classes_per_attribute_women_tops[col]}"
    )

class MultiOutputModel(nn.Module):
    """ResNet101 backbone with one linear classification head per attribute.

    Args:
        num_classes_per_attribute: mapping of attribute name -> number of output
            classes; one ``nn.Linear`` head is created per entry.
        dropout: dropout probability applied to the backbone features before the
            heads (defaults to the previously hard-coded 0.3).
    """

    def __init__(self, num_classes_per_attribute, dropout=0.3):
        super().__init__()

        # Load ResNet101 as the base model. `pretrained=True` is deprecated in
        # torchvision; the weights enum below selects the same ImageNet checkpoint.
        self.base_model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)

        # Save the original fc's in_features before replacing it
        in_features = self.base_model.fc.in_features

        # Replace the fully connected layer with an identity layer so the backbone
        # returns the pooled feature vector instead of ImageNet logits
        self.base_model.fc = nn.Identity()

        # Add dropout
        self.dropout = nn.Dropout(dropout)

        # Create classification heads for each attribute
        self.heads = nn.ModuleDict({
            attr: nn.Linear(in_features, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })

    def forward(self, x):
        """Return a dict mapping each attribute name to that head's logits for batch ``x``."""
        # Extract features using the base model
        features = self.base_model(x)

        # Apply dropout
        features = self.dropout(features)

        # Compute logits for each attribute
        return {attr: head(features) for attr, head in self.heads.items()}


# Instantiate the multi-head ResNet101 classifier for this category
model_resnet101_women_tops = MultiOutputModel(
    num_classes_per_attribute=num_classes_per_attribute_women_tops
)

# Training-time augmentation: random resized crop to 224, horizontal flip and colour
# jitter, followed by tensor conversion and per-channel normalisation.
data_augmentation = T.Compose(
    transforms=[
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Dataset and DataLoader
class CustomDataset(Dataset):
    """Dataset yielding ``(image, labels)`` pairs from a DataFrame of image paths.

    Args:
        df: DataFrame with an ``image_path`` column and one integer-coded column
            per attribute.
        transform: optional callable applied to the loaded RGB PIL image.
        attribute_columns: names of the label columns. When omitted, every column
            of ``df`` whose name starts with ``attr_`` is used, in column order,
            so the dataset no longer depends on a module-level list.
    """

    def __init__(self, df, transform=None, attribute_columns=None):
        self.df = df
        self.transform = transform
        if attribute_columns is None:
            attribute_columns = [c for c in df.columns if str(c).startswith('attr_')]
        self.attribute_columns = list(attribute_columns)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at positional index ``idx`` and return it with a dict of label tensors."""
        # Fetch the row once instead of re-materialising it for every field
        row = self.df.iloc[idx]

        # convert() returns a fresh image, so the file can be closed right away
        with Image.open(row["image_path"]) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels = {attr: torch.tensor(row[attr], dtype=torch.long) for attr in self.attribute_columns}
        return image, labels

# Initialize datasets and data loaders
# Validation and test share one deterministic pipeline (plain resize, no augmentation)
eval_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset_women_tops = CustomDataset(train_df_women_tops, transform=data_augmentation)
val_dataset_women_tops = CustomDataset(val_df_women_tops, transform=eval_transform)
test_dataset_women_tops = CustomDataset(test_df_women_tops, transform=eval_transform)

train_loader_women_tops = DataLoader(train_dataset_women_tops, batch_size=32, shuffle=True, num_workers=4)
val_loader_women_tops = DataLoader(val_dataset_women_tops, batch_size=32, shuffle=False, num_workers=4)
test_loader_women_tops = DataLoader(test_dataset_women_tops, batch_size=32, shuffle=False, num_workers=4)

# Device Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_resnet101_women_tops.to(device)

# Number of epochs is defined before the scheduler so the cosine schedule spans
# exactly the training run instead of relying on a separately hard-coded 10.
num_epochs = 10

# Optimizer and Learning Rate Scheduler
optimizer = torch.optim.AdamW(model_resnet101_women_tops.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler();
# loss scaling is only enabled when actually running on CUDA.
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))

# Training and Validation Loop

# Single loss module reused for every head and batch (CrossEntropyLoss holds no state
# here, so rebuilding it per attribute per step was unnecessary work).
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    # Training
    model_resnet101_women_tops.train()
    for images, labels in tqdm(train_loader_women_tops, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        # Mixed-precision forward pass; torch.autocast replaces the deprecated
        # torch.cuda.amp.autocast() and is a no-op when not running on CUDA.
        with torch.autocast(device_type="cuda", enabled=(device.type == "cuda")):
            outputs = model_resnet101_women_tops(images)
            losses = {attr: criterion(outputs[attr], labels[attr].to(device)) for attr in labels}
            # Unweighted sum of the per-attribute losses
            total_loss = sum(losses.values())

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # Learning rate is stepped once per epoch
    scheduler.step()

    # Validation
    model_resnet101_women_tops.eval()
    val_corrects = {attr: 0 for attr in attribute_columns_women_tops}
    val_samples = {attr: 0 for attr in attribute_columns_women_tops}

    with torch.no_grad():
        for images, labels in tqdm(val_loader_women_tops, desc="Validating", leave=False):
            images = images.to(device)
            outputs = model_resnet101_women_tops(images)
            for attr in labels:
                target = labels[attr].to(device)
                _, preds = torch.max(outputs[attr], 1)
                val_corrects[attr] += torch.sum(preds == target).item()
                val_samples[attr] += target.size(0)

    # Validation Accuracy per Attribute
    for attr in attribute_columns_women_tops:
        acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Epoch {epoch + 1}, Validation Accuracy for {attr}: {acc:.4f}")

# Held-out test evaluation: count correct predictions and samples per attribute
model_resnet101_women_tops.eval()
test_corrects = dict.fromkeys(num_classes_per_attribute_women_tops, 0)
test_samples = dict.fromkeys(num_classes_per_attribute_women_tops, 0)

with torch.no_grad():
    for images, labels in tqdm(test_loader_women_tops, desc="Testing", leave=False):
        images = images.to(device)
        outputs = model_resnet101_women_tops(images)

        for attr in labels:
            target = labels[attr].to(device)
            # Predicted class = index of the largest logit
            preds = outputs[attr].argmax(dim=1)
            test_corrects[attr] += (preds == target).sum().item()
            test_samples[attr] += target.size(0)

# Report accuracy per attribute (0 when an attribute saw no samples)
print("\nTest Accuracy per Attribute:")
for attr in num_classes_per_attribute_women_tops:
    if test_samples[attr] > 0:
        test_acc = test_corrects[attr] / test_samples[attr]
    else:
        test_acc = 0
    print(f"Attribute {attr}: {test_acc:.4f}")


  scaler = GradScaler()
  self.pid = os.fork()
  with autocast():
  self.pid = os.fork()
                                                                      

Epoch 1, Validation Accuracy for attr_1: 0.5826
Epoch 1, Validation Accuracy for attr_2: 0.7320
Epoch 1, Validation Accuracy for attr_3: 0.8492
Epoch 1, Validation Accuracy for attr_4: 0.6777
Epoch 1, Validation Accuracy for attr_5: 0.9902
Epoch 1, Validation Accuracy for attr_6: 0.8913
Epoch 1, Validation Accuracy for attr_7: 0.8285
Epoch 1, Validation Accuracy for attr_8: 0.7973
Epoch 1, Validation Accuracy for attr_9: 0.8095
Epoch 1, Validation Accuracy for attr_10: 0.8141


                                                                      

Epoch 2, Validation Accuracy for attr_1: 0.6359
Epoch 2, Validation Accuracy for attr_2: 0.7527
Epoch 2, Validation Accuracy for attr_3: 0.8446
Epoch 2, Validation Accuracy for attr_4: 0.6980
Epoch 2, Validation Accuracy for attr_5: 0.9902
Epoch 2, Validation Accuracy for attr_6: 0.8846
Epoch 2, Validation Accuracy for attr_7: 0.8285
Epoch 2, Validation Accuracy for attr_8: 0.8120
Epoch 2, Validation Accuracy for attr_9: 0.8257
Epoch 2, Validation Accuracy for attr_10: 0.8155


                                                                      

Epoch 3, Validation Accuracy for attr_1: 0.6328
Epoch 3, Validation Accuracy for attr_2: 0.7468
Epoch 3, Validation Accuracy for attr_3: 0.8506
Epoch 3, Validation Accuracy for attr_4: 0.7085
Epoch 3, Validation Accuracy for attr_5: 0.9902
Epoch 3, Validation Accuracy for attr_6: 0.8895
Epoch 3, Validation Accuracy for attr_7: 0.8369
Epoch 3, Validation Accuracy for attr_8: 0.8173
Epoch 3, Validation Accuracy for attr_9: 0.8334
Epoch 3, Validation Accuracy for attr_10: 0.8162


                                                                      

Epoch 4, Validation Accuracy for attr_1: 0.6629
Epoch 4, Validation Accuracy for attr_2: 0.7573
Epoch 4, Validation Accuracy for attr_3: 0.8513
Epoch 4, Validation Accuracy for attr_4: 0.7236
Epoch 4, Validation Accuracy for attr_5: 0.9902
Epoch 4, Validation Accuracy for attr_6: 0.8990
Epoch 4, Validation Accuracy for attr_7: 0.8373
Epoch 4, Validation Accuracy for attr_8: 0.8246
Epoch 4, Validation Accuracy for attr_9: 0.8351
Epoch 4, Validation Accuracy for attr_10: 0.8180


                                                                      

Epoch 5, Validation Accuracy for attr_1: 0.6731
Epoch 5, Validation Accuracy for attr_2: 0.7636
Epoch 5, Validation Accuracy for attr_3: 0.8530
Epoch 5, Validation Accuracy for attr_4: 0.7197
Epoch 5, Validation Accuracy for attr_5: 0.9902
Epoch 5, Validation Accuracy for attr_6: 0.9004
Epoch 5, Validation Accuracy for attr_7: 0.8418
Epoch 5, Validation Accuracy for attr_8: 0.8278
Epoch 5, Validation Accuracy for attr_9: 0.8355
Epoch 5, Validation Accuracy for attr_10: 0.8204


                                                                      

Epoch 6, Validation Accuracy for attr_1: 0.6749
Epoch 6, Validation Accuracy for attr_2: 0.7576
Epoch 6, Validation Accuracy for attr_3: 0.8590
Epoch 6, Validation Accuracy for attr_4: 0.7219
Epoch 6, Validation Accuracy for attr_5: 0.9902
Epoch 6, Validation Accuracy for attr_6: 0.9011
Epoch 6, Validation Accuracy for attr_7: 0.8457
Epoch 6, Validation Accuracy for attr_8: 0.8299
Epoch 6, Validation Accuracy for attr_9: 0.8478
Epoch 6, Validation Accuracy for attr_10: 0.8204


                                                                      

Epoch 7, Validation Accuracy for attr_1: 0.6770
Epoch 7, Validation Accuracy for attr_2: 0.7664
Epoch 7, Validation Accuracy for attr_3: 0.8597
Epoch 7, Validation Accuracy for attr_4: 0.7320
Epoch 7, Validation Accuracy for attr_5: 0.9902
Epoch 7, Validation Accuracy for attr_6: 0.9000
Epoch 7, Validation Accuracy for attr_7: 0.8464
Epoch 7, Validation Accuracy for attr_8: 0.8239
Epoch 7, Validation Accuracy for attr_9: 0.8373
Epoch 7, Validation Accuracy for attr_10: 0.8218


                                                                      

Epoch 8, Validation Accuracy for attr_1: 0.6906
Epoch 8, Validation Accuracy for attr_2: 0.7639
Epoch 8, Validation Accuracy for attr_3: 0.8527
Epoch 8, Validation Accuracy for attr_4: 0.7345
Epoch 8, Validation Accuracy for attr_5: 0.9902
Epoch 8, Validation Accuracy for attr_6: 0.9011
Epoch 8, Validation Accuracy for attr_7: 0.8453
Epoch 8, Validation Accuracy for attr_8: 0.8292
Epoch 8, Validation Accuracy for attr_9: 0.8485
Epoch 8, Validation Accuracy for attr_10: 0.8211


                                                                      

Epoch 9, Validation Accuracy for attr_1: 0.6917
Epoch 9, Validation Accuracy for attr_2: 0.7713
Epoch 9, Validation Accuracy for attr_3: 0.8558
Epoch 9, Validation Accuracy for attr_4: 0.7397
Epoch 9, Validation Accuracy for attr_5: 0.9902
Epoch 9, Validation Accuracy for attr_6: 0.9021
Epoch 9, Validation Accuracy for attr_7: 0.8443
Epoch 9, Validation Accuracy for attr_8: 0.8358
Epoch 9, Validation Accuracy for attr_9: 0.8453
Epoch 9, Validation Accuracy for attr_10: 0.8232


                                                                       

Epoch 10, Validation Accuracy for attr_1: 0.6882
Epoch 10, Validation Accuracy for attr_2: 0.7710
Epoch 10, Validation Accuracy for attr_3: 0.8544
Epoch 10, Validation Accuracy for attr_4: 0.7373
Epoch 10, Validation Accuracy for attr_5: 0.9902
Epoch 10, Validation Accuracy for attr_6: 0.9011
Epoch 10, Validation Accuracy for attr_7: 0.8460
Epoch 10, Validation Accuracy for attr_8: 0.8327
Epoch 10, Validation Accuracy for attr_9: 0.8481
Epoch 10, Validation Accuracy for attr_10: 0.8239


                                                        


Test Accuracy per Attribute:
Attribute attr_1: 0.6738
Attribute attr_2: 0.7685
Attribute attr_3: 0.8730
Attribute attr_4: 0.7531
Attribute attr_5: 0.9895
Attribute attr_6: 0.9004
Attribute attr_7: 0.8562
Attribute attr_8: 0.8327
Attribute attr_9: 0.8373
Attribute attr_10: 0.8222




In [None]:
# Prepare and process data
# NOTE: this cell rebinds women_tshirts_df (drops columns, then replaces the attribute
# strings with integer codes), so it is meant to run once per fresh kernel: a second
# run fails at drop() below because 'Category' and 'len' are already gone.
women_tshirts_df = women_tshirts_df.drop(columns=['Category', 'len'])

# Missing values: a column that is entirely NaN becomes the constant 'dummy_value';
# otherwise object/category columns are filled with their most frequent value.
# Columns of any other dtype pass through unchanged.
women_tshirts_df = women_tshirts_df.apply(
    lambda col: col.fillna('dummy_value') if col.isna().all()
    else col.fillna(col.mode()[0]) if col.dtype == 'object' or col.dtype.name == 'category'
    else col
)

# Define attribute columns and category mappings
attribute_columns_women_tshirts = [f'attr_{i}' for i in range(1, 11)]
category_mappings_women_tshirts = {}

for col in attribute_columns_women_tshirts:
    women_tshirts_df[col] = pd.Categorical(women_tshirts_df[col])
    # integer code -> original label lookup for this attribute
    category_mappings_women_tshirts[col] = dict(enumerate(women_tshirts_df[col].cat.categories))
    women_tshirts_df[col] = women_tshirts_df[col].cat.codes

# Image file names are the zero-padded 6-digit id with a .jpg extension
women_tshirts_df['image_path'] = '/kaggle/input/visual-taxonomy/train_images/' + women_tshirts_df['id'].apply(lambda x: f"{x:06d}.jpg")

# Split dataset into train, validation, and test sets (70% / 15% / 15%)
train_df_women_tshirts, temp_df = train_test_split(women_tshirts_df, test_size=0.3, random_state=42)
val_df_women_tshirts, test_df_women_tshirts = train_test_split(temp_df, test_size=0.5, random_state=42)

# Define number of classes for each attribute
num_classes_per_attribute_women_tshirts = {
    'attr_1': 7,
    'attr_2': 3,
    'attr_3': 3,
    'attr_4': 3,
    'attr_5': 6,
    'attr_6': 3,
    'attr_7': 2,
    'attr_8': 2,
    'attr_9': 1,
    'attr_10': 1
}

# Fail fast if the hand-written head sizes cannot hold every encoded label; otherwise
# the mismatch only shows up later as an out-of-range target error inside the loss.
for col in attribute_columns_women_tshirts:
    n_labels = len(category_mappings_women_tshirts[col])
    assert n_labels <= num_classes_per_attribute_women_tshirts[col], (
        f"{col}: data contains {n_labels} distinct labels but the head is sized for "
        f"{num_classes_per_attribute_women_tshirts[col]}"
    )

class MultiOutputModel(nn.Module):
    """ResNet101 backbone with one linear classification head per attribute.

    Args:
        num_classes_per_attribute: mapping of attribute name -> number of output
            classes; one ``nn.Linear`` head is created per entry.
        dropout: dropout probability applied to the backbone features before the
            heads (defaults to the previously hard-coded 0.3).
    """

    def __init__(self, num_classes_per_attribute, dropout=0.3):
        super().__init__()

        # Load ResNet101 as the base model. `pretrained=True` is deprecated in
        # torchvision; the weights enum below selects the same ImageNet checkpoint.
        self.base_model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)

        # Save the original fc's in_features before replacing it
        in_features = self.base_model.fc.in_features

        # Replace the fully connected layer with an identity layer so the backbone
        # returns the pooled feature vector instead of ImageNet logits
        self.base_model.fc = nn.Identity()

        # Add dropout
        self.dropout = nn.Dropout(dropout)

        # Create classification heads for each attribute
        self.heads = nn.ModuleDict({
            attr: nn.Linear(in_features, num_classes)
            for attr, num_classes in num_classes_per_attribute.items()
        })

    def forward(self, x):
        """Return a dict mapping each attribute name to that head's logits for batch ``x``."""
        # Extract features using the base model
        features = self.base_model(x)

        # Apply dropout
        features = self.dropout(features)

        # Compute logits for each attribute
        return {attr: head(features) for attr, head in self.heads.items()}


# Instantiate the multi-head ResNet101 classifier for this category
model_resnet101_women_tshirts = MultiOutputModel(
    num_classes_per_attribute=num_classes_per_attribute_women_tshirts
)

# Training-time augmentation: random resized crop to 224, horizontal flip and colour
# jitter, followed by tensor conversion and per-channel normalisation.
data_augmentation = T.Compose(
    transforms=[
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Dataset and DataLoader
class CustomDataset(Dataset):
    """Image dataset backed by a DataFrame with one label column per attribute.

    Args:
        df: DataFrame with an "image_path" column and one column per attribute
            label (assumed to hold integer class indices -- they are converted
            to `torch.long` tensors).
        transform: Optional callable applied to the loaded RGB image.
        attribute_columns: Optional list of label column names to return.
            When omitted, the notebook-level `attribute_columns_women_tshirts`
            list is used, which preserves the previous behaviour.
    """

    def __init__(self, df, transform=None, attribute_columns=None):
        self.df = df
        self.transform = transform
        self.attribute_columns = attribute_columns

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return `(image, labels)` for the row at position `idx`."""
        # Look the row up once instead of once per attribute
        row = self.df.iloc[idx]

        # `convert` returns a new, fully loaded image, so the file handle can
        # be closed as soon as the conversion is done
        with Image.open(row["image_path"]) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        # Resolve the default lazily so the global only needs to exist when
        # items are actually fetched (same timing as before)
        columns = (
            self.attribute_columns
            if self.attribute_columns is not None
            else attribute_columns_women_tshirts
        )
        labels = {attr: torch.tensor(row[attr], dtype=torch.long) for attr in columns}
        return image, labels

# Deterministic preprocessing shared by the validation and test splits
# (defined once instead of being repeated per split; no random augmentation)
eval_transform_women_tshirts = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Initialize datasets: only the training split uses the random augmentation
train_dataset_women_tshirts = CustomDataset(train_df_women_tshirts, transform=data_augmentation)
val_dataset_women_tshirts = CustomDataset(val_df_women_tshirts, transform=eval_transform_women_tshirts)
test_dataset_women_tshirts = CustomDataset(test_df_women_tshirts, transform=eval_transform_women_tshirts)

# Loader settings common to all three splits; only training batches are shuffled
loader_kwargs_women_tshirts = dict(batch_size=32, num_workers=4)
train_loader_women_tshirts = DataLoader(train_dataset_women_tshirts, shuffle=True, **loader_kwargs_women_tshirts)
val_loader_women_tshirts = DataLoader(val_dataset_women_tshirts, shuffle=False, **loader_kwargs_women_tshirts)
test_loader_women_tshirts = DataLoader(test_dataset_women_tshirts, shuffle=False, **loader_kwargs_women_tshirts)

# Device Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_resnet101_women_tshirts.to(device)

# Number of passes over the training set; also used as the cosine schedule
# length so the two values cannot drift apart
num_epochs = 10

# Mixed precision is only active when running on CUDA
use_amp = device.type == "cuda"

# Optimizer and Learning Rate Scheduler
optimizer = torch.optim.AdamW(model_resnet101_women_tshirts.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# `torch.cuda.amp.GradScaler()` / `autocast()` are deprecated and emit a
# FutureWarning; `torch.amp` is the supported replacement
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

# One stateless loss module reused for every attribute head and batch
criterion = nn.CrossEntropyLoss()

# Training and Validation Loop
for epoch in range(num_epochs):
    # Training
    model_resnet101_women_tshirts.train()
    for images, labels in tqdm(train_loader_women_tshirts, desc=f"Training Epoch {epoch + 1}/{num_epochs}", leave=False):
        images = images.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            outputs = model_resnet101_women_tshirts(images)
            # The total loss is the unweighted sum of the per-attribute losses
            losses = {attr: criterion(outputs[attr], labels[attr].to(device)) for attr in labels}
            total_loss = sum(losses.values())

        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and update its scale factor
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # The learning rate schedule advances once per epoch
    scheduler.step()

    # Validation
    model_resnet101_women_tshirts.eval()
    val_corrects = {attr: 0 for attr in attribute_columns_women_tshirts}
    val_samples = {attr: 0 for attr in attribute_columns_women_tshirts}

    with torch.no_grad():
        for images, labels in tqdm(val_loader_women_tshirts, desc="Validating", leave=False):
            images = images.to(device)
            outputs = model_resnet101_women_tshirts(images)
            for attr in labels:
                target = labels[attr].to(device)
                # Predicted class = index of the largest logit
                _, preds = torch.max(outputs[attr], 1)
                val_corrects[attr] += torch.sum(preds == target).item()
                val_samples[attr] += target.size(0)

    # Validation Accuracy per Attribute
    for attr in attribute_columns_women_tshirts:
        acc = val_corrects[attr] / val_samples[attr] if val_samples[attr] > 0 else 0
        print(f"Epoch {epoch + 1}, Validation Accuracy for {attr}: {acc:.4f}")

# Evaluate the model on the test set
model_resnet101_women_tshirts.eval()
# Per-attribute counters of correct predictions and of samples seen
test_corrects = dict.fromkeys(num_classes_per_attribute_women_tshirts, 0)
test_samples = dict.fromkeys(num_classes_per_attribute_women_tshirts, 0)

with torch.no_grad():
    for images, labels in tqdm(test_loader_women_tshirts, desc="Testing", leave=False):
        images = images.to(device)

        # Forward pass: dict of attribute name -> logits
        outputs = model_resnet101_women_tshirts(images)

        for attr in labels:
            # Ground-truth class indices on the same device as the logits
            target = labels[attr].to(device)

            # Predicted class = index of the largest logit in each row
            preds = outputs[attr].argmax(dim=1)

            # Accumulate hit count and batch size for this attribute
            test_corrects[attr] += (preds == target).sum().item()
            test_samples[attr] += target.shape[0]

# Print test accuracy for each attribute
print("\nTest Accuracy per Attribute:")
for attr in num_classes_per_attribute_women_tshirts:
    if test_samples[attr] > 0:
        test_acc = test_corrects[attr] / test_samples[attr]
    else:
        test_acc = 0
    print(f"Attribute {attr}: {test_acc:.4f}")


  scaler = GradScaler()
  self.pid = os.fork()
  with autocast():
  self.pid = os.fork()
                                                                      

Epoch 1, Validation Accuracy for attr_1: 0.7536
Epoch 1, Validation Accuracy for attr_2: 0.8999
Epoch 1, Validation Accuracy for attr_3: 0.7784
Epoch 1, Validation Accuracy for attr_4: 0.9499
Epoch 1, Validation Accuracy for attr_5: 0.6772
Epoch 1, Validation Accuracy for attr_6: 0.9386
Epoch 1, Validation Accuracy for attr_7: 0.9652
Epoch 1, Validation Accuracy for attr_8: 0.9986
Epoch 1, Validation Accuracy for attr_9: 1.0000
Epoch 1, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 2, Validation Accuracy for attr_1: 0.7599
Epoch 2, Validation Accuracy for attr_2: 0.9041
Epoch 2, Validation Accuracy for attr_3: 0.7947
Epoch 2, Validation Accuracy for attr_4: 0.9606
Epoch 2, Validation Accuracy for attr_5: 0.7088
Epoch 2, Validation Accuracy for attr_6: 0.9450
Epoch 2, Validation Accuracy for attr_7: 0.9616
Epoch 2, Validation Accuracy for attr_8: 0.9986
Epoch 2, Validation Accuracy for attr_9: 1.0000
Epoch 2, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 3, Validation Accuracy for attr_1: 0.7731
Epoch 3, Validation Accuracy for attr_2: 0.9084
Epoch 3, Validation Accuracy for attr_3: 0.8097
Epoch 3, Validation Accuracy for attr_4: 0.9574
Epoch 3, Validation Accuracy for attr_5: 0.7109
Epoch 3, Validation Accuracy for attr_6: 0.9478
Epoch 3, Validation Accuracy for attr_7: 0.9670
Epoch 3, Validation Accuracy for attr_8: 0.9986
Epoch 3, Validation Accuracy for attr_9: 1.0000
Epoch 3, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 4, Validation Accuracy for attr_1: 0.7926
Epoch 4, Validation Accuracy for attr_2: 0.9073
Epoch 4, Validation Accuracy for attr_3: 0.8065
Epoch 4, Validation Accuracy for attr_4: 0.9620
Epoch 4, Validation Accuracy for attr_5: 0.7266
Epoch 4, Validation Accuracy for attr_6: 0.9467
Epoch 4, Validation Accuracy for attr_7: 0.9698
Epoch 4, Validation Accuracy for attr_8: 0.9986
Epoch 4, Validation Accuracy for attr_9: 1.0000
Epoch 4, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 5, Validation Accuracy for attr_1: 0.7848
Epoch 5, Validation Accuracy for attr_2: 0.9048
Epoch 5, Validation Accuracy for attr_3: 0.8093
Epoch 5, Validation Accuracy for attr_4: 0.9620
Epoch 5, Validation Accuracy for attr_5: 0.7205
Epoch 5, Validation Accuracy for attr_6: 0.9489
Epoch 5, Validation Accuracy for attr_7: 0.9698
Epoch 5, Validation Accuracy for attr_8: 0.9986
Epoch 5, Validation Accuracy for attr_9: 1.0000
Epoch 5, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 6, Validation Accuracy for attr_1: 0.7873
Epoch 6, Validation Accuracy for attr_2: 0.9077
Epoch 6, Validation Accuracy for attr_3: 0.7994
Epoch 6, Validation Accuracy for attr_4: 0.9641
Epoch 6, Validation Accuracy for attr_5: 0.7219
Epoch 6, Validation Accuracy for attr_6: 0.9460
Epoch 6, Validation Accuracy for attr_7: 0.9673
Epoch 6, Validation Accuracy for attr_8: 0.9986
Epoch 6, Validation Accuracy for attr_9: 1.0000
Epoch 6, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 7, Validation Accuracy for attr_1: 0.7972
Epoch 7, Validation Accuracy for attr_2: 0.9112
Epoch 7, Validation Accuracy for attr_3: 0.8196
Epoch 7, Validation Accuracy for attr_4: 0.9656
Epoch 7, Validation Accuracy for attr_5: 0.7259
Epoch 7, Validation Accuracy for attr_6: 0.9499
Epoch 7, Validation Accuracy for attr_7: 0.9659
Epoch 7, Validation Accuracy for attr_8: 0.9986
Epoch 7, Validation Accuracy for attr_9: 1.0000
Epoch 7, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 8, Validation Accuracy for attr_1: 0.7905
Epoch 8, Validation Accuracy for attr_2: 0.9091
Epoch 8, Validation Accuracy for attr_3: 0.8232
Epoch 8, Validation Accuracy for attr_4: 0.9638
Epoch 8, Validation Accuracy for attr_5: 0.7283
Epoch 8, Validation Accuracy for attr_6: 0.9478
Epoch 8, Validation Accuracy for attr_7: 0.9670
Epoch 8, Validation Accuracy for attr_8: 0.9986
Epoch 8, Validation Accuracy for attr_9: 1.0000
Epoch 8, Validation Accuracy for attr_10: 1.0000


                                                                      

Epoch 9, Validation Accuracy for attr_1: 0.7987
Epoch 9, Validation Accuracy for attr_2: 0.9105
Epoch 9, Validation Accuracy for attr_3: 0.8246
Epoch 9, Validation Accuracy for attr_4: 0.9648
Epoch 9, Validation Accuracy for attr_5: 0.7308
Epoch 9, Validation Accuracy for attr_6: 0.9482
Epoch 9, Validation Accuracy for attr_7: 0.9684
Epoch 9, Validation Accuracy for attr_8: 0.9986
Epoch 9, Validation Accuracy for attr_9: 1.0000
Epoch 9, Validation Accuracy for attr_10: 1.0000


                                                                       

Epoch 10, Validation Accuracy for attr_1: 0.8047
Epoch 10, Validation Accuracy for attr_2: 0.9094
Epoch 10, Validation Accuracy for attr_3: 0.8249
Epoch 10, Validation Accuracy for attr_4: 0.9659
Epoch 10, Validation Accuracy for attr_5: 0.7319
Epoch 10, Validation Accuracy for attr_6: 0.9489
Epoch 10, Validation Accuracy for attr_7: 0.9698
Epoch 10, Validation Accuracy for attr_8: 0.9986
Epoch 10, Validation Accuracy for attr_9: 1.0000
Epoch 10, Validation Accuracy for attr_10: 1.0000


                                                        


Test Accuracy per Attribute:
Attribute attr_1: 0.7941
Attribute attr_2: 0.9049
Attribute attr_3: 0.8332
Attribute attr_4: 0.9684
Attribute attr_5: 0.7341
Attribute attr_6: 0.9489
Attribute attr_7: 0.9727
Attribute attr_8: 0.9996
Attribute attr_9: 1.0000
Attribute attr_10: 1.0000




In [48]:
import torch
from PIL import Image
import torchvision.transforms as T

def predict(image_path, model, category_mappings):
    """Predict the attribute labels of a single image.

    Args:
        image_path: Path of the image file to classify.
        model: Module whose forward pass returns a dict mapping attribute name
            to a logits tensor with a leading batch dimension.
        category_mappings: Dict mapping attribute name to a dict that maps a
            class index to its label value.

    Returns:
        Dict mapping each attribute name returned by the model to the label
        value of its highest-scoring class.
    """
    # Load the image; `convert` returns a new, fully loaded image, so the file
    # handle is closed as soon as the `with` block ends
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Put the input on the model's own device so the two always match; fall
    # back to the notebook-level `device` for a model without parameters
    try:
        target_device = next(model.parameters()).device
    except StopIteration:
        target_device = device

    # Add batch dimension and move to device
    image = transform(image).unsqueeze(0).to(target_device)

    # Make the prediction
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        outputs = model(image)

    # Convert the predictions (logits) to class labels
    predicted_labels = {}
    for attr, output in outputs.items():
        class_index = output.argmax(dim=1).item()  # Index of the largest logit
        predicted_labels[attr] = category_mappings[attr][class_index]  # Map to corresponding value

    return predicted_labels

# Example usage: Predict for a sample image
image_path = "/kaggle/input/visual-taxonomy/test_images/000000.jpg"
# Call `predict`, the function defined above. The previous call target,
# `predict_image`, is not defined in this cell and only resolved through
# leftover kernel state.
predictions = predict(image_path, model_resnet101, category_mappings_men_tshirts)

# Print the predictions
print("Predictions for the image:")
for attr, label in predictions.items():
    print(f"{attr}: {label}")


Predictions for the image:
attr_1: default
attr_2: polo
attr_3: solid
attr_4: solid
attr_5: short sleeves
attr_6: dummy_value
attr_7: dummy_value
attr_8: dummy_value
attr_9: dummy_value
attr_10: dummy_value


In [41]:
# Save the weights (state_dict) of every per-category model. The dict key is
# the category part of the checkpoint file name.
checkpoint_models = {
    'men_tshirts': model_resnet101,
    'women_tops': model_resnet101_women_tops,
    'women_tshirts': model_resnet101_women_tshirts,
    'kurtis': model_kurtis,
    'sarees': model_resnet101_sarees,
}
for category_slug, trained_model in checkpoint_models.items():
    torch.save(trained_model.state_dict(), f'model_vit_{category_slug}_resnet101_new.pth')


In [49]:
from tqdm import tqdm
import contextlib
from IPython.display import HTML
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import pandas as pd
import io

# Per-category resources: ('len' value, trained model, index -> label mappings)
category_resources = {
    'Men Tshirts': (5, model_resnet101, category_mappings_men_tshirts),
    'Sarees': (10, model_resnet101_sarees, category_mappings_sarees),
    'Kurtis': (9, model_kurtis, category_mappings_kurtis),
    'Women Tshirts': (8, model_resnet101_women_tshirts, category_mappings_women_tshirts),
    'Women Tops & Tunics': (10, model_resnet101_women_tops, category_mappings_women_tops),
}

# Split the table into the three lookups used below
categories_number = {name: entry[0] for name, entry in category_resources.items()}
model_selection = {name: entry[1] for name, entry in category_resources.items()}
mapping_selection = {name: entry[2] for name, entry in category_resources.items()}

# One result record per test image
final_results = []

# Walk the (id, Category) pairs of the test DataFrame under a single progress bar
test_id_category_pairs = zip(test_df['id'], test_df['Category'])
for sample_id, category in tqdm(test_id_category_pairs, total=len(test_df), desc="Processing Images"):
    # Image files are named after the id, zero-padded to six digits
    image_path = f"/kaggle/input/visual-taxonomy/test_images/{sample_id:06}.jpg"

    # Swallow anything written to stdout while predicting
    with contextlib.redirect_stdout(io.StringIO()):
        predictions = predict(image_path, model_selection[category], mapping_selection[category])

    # Record: identifying columns followed by one column per predicted attribute
    final_results.append({
        'id': sample_id,
        'Category': category,
        'len': categories_number[category],
        **predictions,
    })

# Assemble all records into a DataFrame and write it out as CSV
final_df = pd.DataFrame(final_results)
final_df.to_csv('final_predictions_resnet101_new.csv', index=False)

print("Final CSV generated and saved as 'final_predictions_resnet101_new.csv'.")

Processing Images: 100%|██████████| 30205/30205 [16:43<00:00, 30.11it/s]


Final CSV generated and saved as 'final_predictions_resnet101_new.csv'.


Use the trained models to generate predictions for the training dataset; these predictions are used as inputs to the final stacked ensemble model.

In [56]:
from tqdm import tqdm
import contextlib
from IPython.display import HTML
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import pandas as pd
import io

# Per-category resources: ('len' value, trained model, index -> label mappings)
category_resources = {
    'Men Tshirts': (5, model_resnet101, category_mappings_men_tshirts),
    'Sarees': (10, model_resnet101_sarees, category_mappings_sarees),
    'Kurtis': (9, model_kurtis, category_mappings_kurtis),
    'Women Tshirts': (8, model_resnet101_women_tshirts, category_mappings_women_tshirts),
    'Women Tops & Tunics': (10, model_resnet101_women_tops, category_mappings_women_tops),
}

# Split the table into the three lookups used below
categories_number = {name: entry[0] for name, entry in category_resources.items()}
model_selection = {name: entry[1] for name, entry in category_resources.items()}
mapping_selection = {name: entry[2] for name, entry in category_resources.items()}

# One result record per training image
final_results = []

# Walk the (id, Category) pairs of the train DataFrame under a single progress bar
train_id_category_pairs = zip(train_df['id'], train_df['Category'])
for sample_id, category in tqdm(train_id_category_pairs, total=len(train_df), desc="Processing Images"):
    # Image files are named after the id, zero-padded to six digits
    image_path = f"/kaggle/input/visual-taxonomy/train_images/{sample_id:06}.jpg"

    # Swallow anything written to stdout while predicting
    with contextlib.redirect_stdout(io.StringIO()):
        predictions = predict(image_path, model_selection[category], mapping_selection[category])

    # Record: identifying columns followed by one column per predicted attribute
    final_results.append({
        'id': sample_id,
        'Category': category,
        'len': categories_number[category],
        **predictions,
    })

# Assemble all records into a DataFrame and write it out as CSV
final_df = pd.DataFrame(final_results)
final_df.to_csv('final_predictions_resnet101_new_train.csv', index=False)

print("Final CSV generated and saved as 'final_predictions_resnet101_new_train.csv'.")

Processing Images: 100%|██████████| 70213/70213 [30:20<00:00, 38.57it/s]


Final CSV generated and saved as 'final_predictions_resnet101_new_train.csv'.


In [7]:
# Class-index -> label tables, one per category. Each attribute's labels are
# listed in class-index order, so `dict(enumerate(...))` yields {0: ..., 1: ...}.
# Attributes a category does not use carry the single placeholder 'dummy_value'.
category_mappings_men_tshirts = {
    'attr_1': dict(enumerate(['black', 'default', 'multicolor', 'white'])),
    'attr_2': dict(enumerate(['polo', 'round'])),
    'attr_3': dict(enumerate(['printed', 'solid'])),
    'attr_4': dict(enumerate(['default', 'solid', 'typography'])),
    'attr_5': dict(enumerate(['long sleeves', 'short sleeves'])),
    'attr_6': dict(enumerate(['dummy_value'])),
    'attr_7': dict(enumerate(['dummy_value'])),
    'attr_8': dict(enumerate(['dummy_value'])),
    'attr_9': dict(enumerate(['dummy_value'])),
    'attr_10': dict(enumerate(['dummy_value'])),
}
category_mappings_sarees = {
    'attr_1': dict(enumerate(['default', 'same as border', 'same as saree', 'solid'])),
    'attr_2': dict(enumerate(['default', 'no border', 'solid', 'temple border', 'woven design', 'zari'])),
    'attr_3': dict(enumerate(['big border', 'no border', 'small border'])),
    'attr_4': dict(enumerate(['cream', 'default', 'green', 'multicolor', 'navy blue', 'pink', 'white', 'yellow'])),
    'attr_5': dict(enumerate(['daily', 'party', 'traditional', 'wedding'])),
    'attr_6': dict(enumerate(['default', 'jacquard', 'tassels and latkans'])),
    'attr_7': dict(enumerate(['default', 'same as saree', 'woven design', 'zari woven'])),
    'attr_8': dict(enumerate(['default', 'printed', 'solid', 'woven design', 'zari woven'])),
    'attr_9': dict(enumerate(['applique', 'botanical', 'checked', 'default', 'elephant', 'ethnic motif', 'floral', 'peacock', 'solid'])),
    'attr_10': dict(enumerate(['no', 'yes'])),
}
category_mappings_kurtis = {
    'attr_1': dict(enumerate(['black', 'blue', 'green', 'grey', 'maroon', 'multicolor', 'navy blue', 'orange', 'pink', 'purple', 'red', 'white', 'yellow'])),
    'attr_2': dict(enumerate(['a-line', 'straight'])),
    'attr_3': dict(enumerate(['calf length', 'knee length'])),
    'attr_4': dict(enumerate(['daily', 'party'])),
    'attr_5': dict(enumerate(['default', 'net'])),
    'attr_6': dict(enumerate(['default', 'solid'])),
    'attr_7': dict(enumerate(['default', 'solid'])),
    'attr_8': dict(enumerate(['short sleeves', 'sleeveless', 'three-quarter sleeves'])),
    'attr_9': dict(enumerate(['regular', 'sleeveless'])),
    'attr_10': dict(enumerate(['dummy_value'])),
}
category_mappings_women_tshirts = {
    'attr_1': dict(enumerate(['black', 'default', 'maroon', 'multicolor', 'pink', 'white', 'yellow'])),
    'attr_2': dict(enumerate(['boxy', 'loose', 'regular'])),
    'attr_3': dict(enumerate(['crop', 'long', 'regular'])),
    'attr_4': dict(enumerate(['default', 'printed', 'solid'])),
    'attr_5': dict(enumerate(['default', 'funky print', 'graphic', 'quirky', 'solid', 'typography'])),
    'attr_6': dict(enumerate(['default', 'long sleeves', 'short sleeves'])),
    'attr_7': dict(enumerate(['cuffed sleeves', 'regular sleeves'])),
    'attr_8': dict(enumerate(['applique', 'default'])),
    'attr_9': dict(enumerate(['dummy_value'])),
    'attr_10': dict(enumerate(['dummy_value'])),
}
category_mappings_women_tops = {
    'attr_1': dict(enumerate(['black', 'blue', 'default', 'green', 'maroon', 'multicolor', 'navy blue', 'peach', 'pink', 'red', 'white', 'yellow'])),
    'attr_2': dict(enumerate(['boxy', 'default', 'fitted', 'regular'])),
    'attr_3': dict(enumerate(['crop', 'regular'])),
    'attr_4': dict(enumerate(['default', 'high', 'round neck', 'square neck', 'stylised', 'sweetheart neck', 'v-neck'])),
    'attr_5': dict(enumerate(['casual', 'party'])),
    'attr_6': dict(enumerate(['default', 'printed', 'solid'])),
    'attr_7': dict(enumerate(['default', 'floral', 'graphic', 'quirky', 'solid', 'typography'])),
    'attr_8': dict(enumerate(['long sleeves', 'short sleeves', 'sleeveless', 'three-quarter sleeves'])),
    'attr_9': dict(enumerate(['default', 'puff sleeves', 'regular sleeves', 'sleeveless'])),
    'attr_10': dict(enumerate(['applique', 'default', 'knitted', 'ruffles', 'tie-ups', 'waist tie-ups'])),
}

In [51]:
# Show the class-index -> label mapping used for the 'Men Tshirts' category
print(category_mappings_men_tshirts)

{'attr_1': {0: 'black', 1: 'default', 2: 'multicolor', 3: 'white'}, 'attr_2': {0: 'polo', 1: 'round'}, 'attr_3': {0: 'printed', 1: 'solid'}, 'attr_4': {0: 'default', 1: 'solid', 2: 'typography'}, 'attr_5': {0: 'long sleeves', 1: 'short sleeves'}, 'attr_6': {0: 'dummy_value'}, 'attr_7': {0: 'dummy_value'}, 'attr_8': {0: 'dummy_value'}, 'attr_9': {0: 'dummy_value'}, 'attr_10': {0: 'dummy_value'}}


In [52]:
# Show the class-index -> label mapping used for the 'Sarees' category
print(category_mappings_sarees)

{'attr_1': {0: 'default', 1: 'same as border', 2: 'same as saree', 3: 'solid'}, 'attr_2': {0: 'default', 1: 'no border', 2: 'solid', 3: 'temple border', 4: 'woven design', 5: 'zari'}, 'attr_3': {0: 'big border', 1: 'no border', 2: 'small border'}, 'attr_4': {0: 'cream', 1: 'default', 2: 'green', 3: 'multicolor', 4: 'navy blue', 5: 'pink', 6: 'white', 7: 'yellow'}, 'attr_5': {0: 'daily', 1: 'party', 2: 'traditional', 3: 'wedding'}, 'attr_6': {0: 'default', 1: 'jacquard', 2: 'tassels and latkans'}, 'attr_7': {0: 'default', 1: 'same as saree', 2: 'woven design', 3: 'zari woven'}, 'attr_8': {0: 'default', 1: 'printed', 2: 'solid', 3: 'woven design', 4: 'zari woven'}, 'attr_9': {0: 'applique', 1: 'botanical', 2: 'checked', 3: 'default', 4: 'elephant', 5: 'ethnic motif', 6: 'floral', 7: 'peacock', 8: 'solid'}, 'attr_10': {0: 'no', 1: 'yes'}}


In [53]:
# Show the class-index -> label mapping used for the 'Kurtis' category
print(category_mappings_kurtis)

{'attr_1': {0: 'black', 1: 'blue', 2: 'green', 3: 'grey', 4: 'maroon', 5: 'multicolor', 6: 'navy blue', 7: 'orange', 8: 'pink', 9: 'purple', 10: 'red', 11: 'white', 12: 'yellow'}, 'attr_2': {0: 'a-line', 1: 'straight'}, 'attr_3': {0: 'calf length', 1: 'knee length'}, 'attr_4': {0: 'daily', 1: 'party'}, 'attr_5': {0: 'default', 1: 'net'}, 'attr_6': {0: 'default', 1: 'solid'}, 'attr_7': {0: 'default', 1: 'solid'}, 'attr_8': {0: 'short sleeves', 1: 'sleeveless', 2: 'three-quarter sleeves'}, 'attr_9': {0: 'regular', 1: 'sleeveless'}, 'attr_10': {0: 'dummy_value'}}


In [54]:
# Show the class-index -> label mapping used for the 'Women Tshirts' category
print(category_mappings_women_tshirts)

{'attr_1': {0: 'black', 1: 'default', 2: 'maroon', 3: 'multicolor', 4: 'pink', 5: 'white', 6: 'yellow'}, 'attr_2': {0: 'boxy', 1: 'loose', 2: 'regular'}, 'attr_3': {0: 'crop', 1: 'long', 2: 'regular'}, 'attr_4': {0: 'default', 1: 'printed', 2: 'solid'}, 'attr_5': {0: 'default', 1: 'funky print', 2: 'graphic', 3: 'quirky', 4: 'solid', 5: 'typography'}, 'attr_6': {0: 'default', 1: 'long sleeves', 2: 'short sleeves'}, 'attr_7': {0: 'cuffed sleeves', 1: 'regular sleeves'}, 'attr_8': {0: 'applique', 1: 'default'}, 'attr_9': {0: 'dummy_value'}, 'attr_10': {0: 'dummy_value'}}


In [55]:
# Show the class-index -> label mapping used for the 'Women Tops & Tunics' category
print(category_mappings_women_tops)

{'attr_1': {0: 'black', 1: 'blue', 2: 'default', 3: 'green', 4: 'maroon', 5: 'multicolor', 6: 'navy blue', 7: 'peach', 8: 'pink', 9: 'red', 10: 'white', 11: 'yellow'}, 'attr_2': {0: 'boxy', 1: 'default', 2: 'fitted', 3: 'regular'}, 'attr_3': {0: 'crop', 1: 'regular'}, 'attr_4': {0: 'default', 1: 'high', 2: 'round neck', 3: 'square neck', 4: 'stylised', 5: 'sweetheart neck', 6: 'v-neck'}, 'attr_5': {0: 'casual', 1: 'party'}, 'attr_6': {0: 'default', 1: 'printed', 2: 'solid'}, 'attr_7': {0: 'default', 1: 'floral', 2: 'graphic', 3: 'quirky', 4: 'solid', 5: 'typography'}, 'attr_8': {0: 'long sleeves', 1: 'short sleeves', 2: 'sleeveless', 3: 'three-quarter sleeves'}, 'attr_9': {0: 'default', 1: 'puff sleeves', 2: 'regular sleeves', 3: 'sleeveless'}, 'attr_10': {0: 'applique', 1: 'default', 2: 'knitted', 3: 'ruffles', 4: 'tie-ups', 5: 'waist tie-ups'}}
