In [1]:
import os
import cv2
import tqdm
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, accuracy_score
from warnings import filterwarnings
filterwarnings('ignore')

In [2]:
# import torch
# if torch.backends.mps.is_available():
#     mps_device = torch.device("mps")
#     x = torch.ones(1, device=mps_device)
#     print (x)
# else:
#     print ("MPS device not found.")

## CNN

In [3]:
# Global image directory path.
# Set the MEESHO_IMG_DIR environment variable to point at the image folder on
# your machine; when it is unset or empty, the original local path is used.
IMG_DIR = os.environ.get("MEESHO_IMG_DIR") or "/Users/susanketsarkar/Desktop/Code/Meesho/data/train_images"

class CustomImageDataset(Dataset):
    """Dataset that yields ``(image, label)`` pairs described by a DataFrame.

    Each row of ``dataframe`` must provide an ``image_path`` column (path of
    the image file to read with OpenCV) and a ``label`` column.

    Args:
        dataframe: pandas DataFrame with ``image_path`` and ``label`` columns.
        transform: optional callable applied to the RGB image array before it
            is returned.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the ``idx``-th sample by position.

        Returns:
            Tuple ``(image, label)`` where ``image`` is the RGB image (after
            ``transform`` when one was given) and ``label`` is the row's
            ``label`` value.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # Look the row up once instead of once per column.
        row = self.dataframe.iloc[idx]
        img_path = row['image_path']

        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising;
        # without this check cvtColor fails with an opaque OpenCV assertion.
        if image is None:
            raise FileNotFoundError(
                f"Could not read image '{img_path}' (file missing, unreadable, or unsupported format)"
            )
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image)

        return image, row['label']

def load_data(csv_path, attr_to_predict):
    """Read the attribute CSV and keep the labelled rows for one attribute.

    Args:
        csv_path: path of a CSV file containing at least the columns ``id``,
            ``Category`` and ``attr_to_predict``.
        attr_to_predict: name of the target attribute column.

    Returns:
        A new DataFrame with the columns ``id``, ``Category``,
        ``attr_to_predict`` (each listed once) and ``image_path``. Rows whose
        target attribute is missing are dropped; the original index is kept.

    Raises:
        KeyError: if any required column is absent from the CSV.
    """
    raw = pd.read_csv(csv_path)

    # dict.fromkeys de-duplicates while preserving order, so choosing 'id' or
    # 'Category' as the target does not produce a repeated column.
    columns = list(dict.fromkeys(['id', 'Category', attr_to_predict]))
    missing = [col for col in columns if col not in raw.columns]
    if missing:
        raise KeyError(f"{csv_path} is missing required column(s): {missing}")

    # Build an independent frame instead of mutating a column slice in place,
    # which is what triggered pandas' chained-assignment warning before.
    df = raw.loc[:, columns].dropna(subset=[attr_to_predict]).copy()

    # Image files are named after the id, left-padded with zeros to 6 characters.
    df['image_path'] = df['id'].apply(
        lambda x: os.path.join(IMG_DIR, f"{str(x).zfill(6)}.jpg")
    )

    return df

def preprocess_data(df, attr_to_predict):
    """Add an integer ``label`` column encoding ``attr_to_predict``.

    A fresh ``LabelEncoder`` is fitted on ``df[attr_to_predict]`` and the
    encoded values are written to ``df['label']``. The input frame is modified
    in place and is also the frame that is returned.

    Returns:
        Tuple ``(df, encoder)``: the same DataFrame with the new ``label``
        column, and the fitted encoder.
    """
    encoder = LabelEncoder()
    encoded_labels = encoder.fit_transform(df[attr_to_predict])
    df['label'] = encoded_labels
    return df, encoder

def build_cnn_model(num_classes, input_size=64):
    """Build a small two-block CNN classifier for 3-channel images.

    Architecture: two ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``
    blocks (32 then 64 channels), followed by ``Flatten`` and a 128-unit
    hidden layer feeding ``num_classes`` output logits.

    Args:
        num_classes: number of output logits.
        input_size: spatial size of the input images, either a single int for
            square images or a ``(height, width)`` pair. Defaults to 64, which
            reproduces the original fixed ``64 * 16 * 16`` flatten size.

    Returns:
        An ``nn.Sequential`` model returning raw (un-normalised) logits.

    Raises:
        ValueError: if either spatial dimension is smaller than 4, because the
            two pooling stages would then leave no features to flatten.
    """
    if isinstance(input_size, int):
        height = width = input_size
    else:
        height, width = input_size

    if height < 4 or width < 4:
        raise ValueError(
            f"input_size must be at least 4 in each dimension, got {input_size!r}"
        )

    # The padded 3x3 convolutions keep the spatial size; each 2x2/stride-2
    # pool halves it (rounding down), so two pools divide each side by 4.
    flat_features = 64 * (height // 4) * (width // 4)

    model = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        nn.Flatten(),
        nn.Linear(flat_features, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes)
    )
    return model

def train_cnn_on_attribute(csv_path, attr_to_predict, epochs=10, batch_size=4, seed=None):
    """Train a small CNN to predict one attribute from images and report metrics.

    The CSV is loaded with ``load_data``, labels are encoded with
    ``preprocess_data``, and the samples are split 80/20 into train and test
    sets. The model from ``build_cnn_model`` is trained with Adam
    (``lr=0.001``) and cross-entropy loss, then evaluated on the test split.
    A classification report and the accuracy are printed; nothing is returned.

    Args:
        csv_path: path of the attribute CSV file.
        attr_to_predict: name of the target attribute column.
        epochs: number of passes over the training split.
        batch_size: mini-batch size for both data loaders.
        seed: optional integer passed to ``torch.manual_seed`` before the
            split, so the split, shuffling and weight initialisation start
            from a fixed RNG state. ``None`` (default) leaves the RNG alone.

    Raises:
        ValueError: if there are fewer than 2 labelled rows, since no
            non-empty train/test split can be formed.
    """
    # Load and prepare data
    df = load_data(csv_path, attr_to_predict)
    df, le = preprocess_data(df, attr_to_predict)

    # Images are resized to 64x64, the input size the classifier head of
    # build_cnn_model is sized for by default.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    dataset = CustomImageDataset(dataframe=df, transform=transform)

    # Split the dataset
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    if train_size == 0:
        # Covers both an empty dataset and a single-row dataset.
        raise ValueError(
            f"Need at least 2 labelled rows for '{attr_to_predict}' to build "
            f"train/test splits, got {len(dataset)}"
        )

    if seed is not None:
        # random_split, DataLoader shuffling and layer initialisation all draw
        # from torch's global RNG by default.
        torch.manual_seed(seed)

    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Build CNN model
    print("Building the model...")
    num_classes = len(le.classes_)
    model = build_cnn_model(num_classes=num_classes)

    # Move model to an accelerator if available: CUDA first, then Apple MPS.
    # getattr guards against torch builds that have no `mps` backend module.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    print(f"Training the model on {train_size} data points...")
    model.train()
    for epoch in range(epochs):
        print(f"Epoch {epoch} running...")
        for images, labels in tqdm.tqdm(train_loader):
            images = images.to(device)
            # CrossEntropyLoss needs int64 class indices; the encoder's integer
            # width can differ by platform, so cast explicitly.
            labels = labels.to(device=device, dtype=torch.long)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Evaluate the model
    print("Evaluating the model...")
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            _, preds = torch.max(outputs, 1)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # Print metrics
    # Passing `labels` explicitly keeps the report aligned with target_names
    # even when a class is absent from the test split and the predictions;
    # otherwise classification_report raises a ValueError on the mismatch.
    # Names are stringified because the report formats them as text.
    class_ids = np.arange(num_classes)
    class_names = [str(name) for name in le.classes_]
    print("Classification Report:")
    print(classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=class_names,
        zero_division=0,
    ))
    print(f"Accuracy: {accuracy_score(all_labels, all_preds) * 100:.2f}%")

In [12]:
# Train and evaluate the sleeve-length classifier on the men's T-shirt subset.
csv_path = "../data/cat_wise_csv/Men_Tshirts_data.csv"
attr_to_predict = 'sleeve_length'
train_cnn_on_attribute(
    csv_path=csv_path,
    attr_to_predict=attr_to_predict,
    epochs=10,
    batch_size=32,
)

Building the model...
Training the model on 4781 data points...
Epoch 0 running...


100%|██████████| 150/150 [00:23<00:00,  6.45it/s]


Epoch 1 running...


100%|██████████| 150/150 [00:15<00:00,  9.82it/s]


Epoch 2 running...


100%|██████████| 150/150 [00:20<00:00,  7.36it/s]


Epoch 3 running...


100%|██████████| 150/150 [00:12<00:00, 12.19it/s]


Epoch 4 running...


100%|██████████| 150/150 [00:15<00:00,  9.59it/s]


Epoch 5 running...


100%|██████████| 150/150 [00:15<00:00,  9.62it/s]


Epoch 6 running...


100%|██████████| 150/150 [00:17<00:00,  8.60it/s]


Epoch 7 running...


100%|██████████| 150/150 [00:15<00:00,  9.42it/s]


Epoch 8 running...


100%|██████████| 150/150 [00:15<00:00,  9.70it/s]


Epoch 9 running...


100%|██████████| 150/150 [00:19<00:00,  7.81it/s]


Evaluating the model...
Classification Report:
               precision    recall  f1-score   support

 long sleeves       0.82      0.75      0.78        71
short sleeves       0.98      0.99      0.99      1125

     accuracy                           0.97      1196
    macro avg       0.90      0.87      0.88      1196
 weighted avg       0.97      0.97      0.97      1196

Accuracy: 97.49%
