In [1]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import pandas as pd
from PIL import ImageDraw, ImageFont, Image
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch
import torch.nn as nn
from transformers import ViTModel
from torchinfo import summary  # 
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
import random
import time

In [2]:
DEVICE="cuda:0"
def setAllSeeds(seed):
  os.environ['MY_GLOBAL_SEED'] = str(seed)
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
setAllSeeds(42)

In [3]:
def show_fashionImage(image):
    """Show image with landmarks"""
    plt.imshow(image)

In [4]:
category="Men Tshirts"
df = pd.read_csv("train.csv")
df = df[df["Category"]==category]

In [5]:
delCol = []
idxCol = []
trackNum = []
for i in range(1,11):
    uniName = df["attr_"+str(i)].unique()
    # print(len(uniName))
    if(len(uniName)==1):
        delCol.append("attr_"+str(i))
    else:
        idxCol.append("attr_"+str(i))
        trackNum.append(len(uniName))

In [6]:
delCol

['attr_6', 'attr_7', 'attr_8', 'attr_9', 'attr_10']

In [7]:
df = df.drop(delCol,axis=1)

In [8]:
df = df.fillna("NA")

In [9]:
# df=df[0:100]

In [10]:
id2label={}
label2id={}
attrs={}
total_attr=len(df.columns)
for i in range(3,total_attr):
    labels=df[df.columns[i]].unique()
    # print(df.columns[i],labels)
    id2label[i-3]={k:labels[k] for k in range(len(labels))}
    label2id[i-3]={labels[k]:k for k in range(len(labels))}
    attrs[i-3]=df.columns[i]
print(id2label)
print(label2id)
print(attrs)

{0: {0: 'default', 1: 'multicolor', 2: 'black', 3: 'white', 4: 'NA'}, 1: {0: 'round', 1: 'polo', 2: 'NA'}, 2: {0: 'printed', 1: 'solid', 2: 'NA'}, 3: {0: 'default', 1: 'solid', 2: 'NA', 3: 'typography'}, 4: {0: 'short sleeves', 1: 'long sleeves', 2: 'NA'}}
{0: {'default': 0, 'multicolor': 1, 'black': 2, 'white': 3, 'NA': 4}, 1: {'round': 0, 'polo': 1, 'NA': 2}, 2: {'printed': 0, 'solid': 1, 'NA': 2}, 3: {'default': 0, 'solid': 1, 'NA': 2, 'typography': 3}, 4: {'short sleeves': 0, 'long sleeves': 1, 'NA': 2}}
{0: 'attr_1', 1: 'attr_2', 2: 'attr_3', 3: 'attr_4', 4: 'attr_5'}


In [11]:
def categorize(example):
    for i in attrs:
        example[attrs[i]]=label2id[i][example[attrs[i]]]
    return example
df=df.apply(categorize,axis=1)
    

In [12]:
from transformers import ViTImageProcessor
model_name = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(model_name)

In [13]:
from sklearn.model_selection import train_test_split

train_df, val_df = train_test_split(df, test_size=0.3)
val_df,test_df=train_test_split(val_df,test_size=0.33)

In [14]:
class CustomFashionManager(Dataset):
    def __init__(self,csv_file, root_dir="./",transforms =None):
        self.fashionItems = csv_file
        self.root_dir = root_dir
        self.transforms = transforms
    
    def __len__(self):
        return len(self.fashionItems)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,f"{self.fashionItems.iloc[idx, 0]:06d}"+'.jpg')
        image = Image.open(img_name)
        attributes = self.fashionItems.iloc[idx, 3:]
        attributes = np.array(attributes)
        attributes = attributes.astype('float')
        # print(attributes.shape)
        # attributes = attributes.astype('float').reshape(-1, len(attributes))
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        inputs=processor(image, return_tensors='pt')
        inputs['labels']=torch.tensor(attributes, dtype=torch.long)
        return inputs

        # if self.transforms:
        #     sample = self.transforms(sample)

        # return sample


In [15]:
train_fashion_data = CustomFashionManager(csv_file=train_df,
                                    root_dir='train_images')
val_fashion_data = CustomFashionManager(csv_file=val_df,
                                    root_dir='train_images')
test_fashion_data = CustomFashionManager(csv_file=test_df,root_dir='train_images')

fig = plt.figure()
        

<Figure size 640x480 with 0 Axes>

In [80]:
import sys
from typing import List
from transformers import ViTConfig,ViTPreTrainedModel


class CustomConfig(ViTConfig):
    def __init__(self,num_classes_per_label:List[int]=[1],**kwargs):
        super().__init__(**kwargs)
        self.num_classes_per_label = num_classes_per_label

class MultiLabelMultiClassViT(ViTPreTrainedModel):
    config_class=CustomConfig
    def __init__(self, config: CustomConfig) -> None:
        super().__init__(config)

        self.vit = ViTModel(config, add_pooling_layer=False)
        self.classifiers = nn.ModuleList([
            nn.Linear(config.hidden_size, num_classes) 
            for num_classes in config.num_classes_per_label
        ])
        # Initialize weights and apply final processing
        self.post_init()


    def forward(self, pixel_values,labels=None):
        outputs = self.vit(pixel_values).last_hidden_state[:, 0, :]  # CLS token representation
        logits = [classifier(outputs) for classifier in self.classifiers]
        if labels is not None:
            loss=0
            for i in range(len(logits)):
                target=labels[:,i]
                loss += torch.nn.functional.cross_entropy(logits[i], target)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

# Example usage
num_labels = len(trackNum)  # For example, 5 different labels



In [81]:
from transformers import Trainer, TrainingArguments
from sklearn.metrics import classification_report
batch_size = 32
def collate_fn(batch):
    return {
        'pixel_values': torch.cat([x['pixel_values'] for x in batch],dim=0),
        'labels': torch.stack([x['labels'] for x in batch])
    }

def compute_metrics(pred):
    logits = pred.predictions
    labels=pred.label_ids
    probs = np.stack([np.argmax(logit,axis=1) for logit in logits])
    probs=probs.T
    report=classification_report(labels.flatten(),probs.flatten(),output_dict=True)
    return {'accuracy': report['accuracy'],"macro avg f1":report['macro avg']['f1-score']}

training_args = TrainingArguments(
  output_dir="./vit",
  per_device_train_batch_size=48,
  per_device_eval_batch_size=48,
  evaluation_strategy="epoch",
  save_strategy="epoch",
  num_train_epochs=1,
  fp16=True,
  learning_rate=2e-4,
  save_total_limit=1,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='wandb',
  load_best_model_at_end=True,
)
config=ViTConfig.from_pretrained(model_name)
config=CustomConfig(num_classes_per_label=trackNum,**config.to_dict())
model = MultiLabelMultiClassViT.from_pretrained(model_name,config=config)
trainer = Trainer(
    model,
    training_args,
    train_dataset=train_fashion_data,
    eval_dataset=val_fashion_data,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

[5, 3, 3, 4, 3]


Some weights of MultiLabelMultiClassViT were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['classifiers.0.bias', 'classifiers.0.weight', 'classifiers.1.bias', 'classifiers.1.weight', 'classifiers.2.bias', 'classifiers.2.weight', 'classifiers.3.bias', 'classifiers.3.weight', 'classifiers.4.bias', 'classifiers.4.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [82]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Macro avg f1
1,No log,2.950229,0.761396,0.600754


TrainOutput(global_step=53, training_loss=3.3023929236070164, metrics={'train_runtime': 63.6514, 'train_samples_per_second': 79.904, 'train_steps_per_second': 0.833, 'total_flos': 3.941807777569751e+17, 'train_loss': 3.3023929236070164, 'epoch': 1.0})

In [83]:
trainer.evaluate(test_fashion_data)

{'eval_loss': 2.8807997703552246,
 'eval_accuracy': 0.7627777777777778,
 'eval_macro avg f1': 0.6150085912468228,
 'eval_runtime': 4.8588,
 'eval_samples_per_second': 148.186,
 'eval_steps_per_second': 1.647,
 'epoch': 1.0}

In [84]:
import gc
del model, trainer
torch.cuda.empty_cache()
gc.collect()

100