In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torch.nn import functional as F
import pandas as pd
import cv2
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, accuracy_score
import tqdm
import numpy as np

from warnings import filterwarnings
filterwarnings("ignore")

In [2]:
IMG_DIR = "../data/train_images/"

In [3]:
# import torch
# if torch.backends.mps.is_available():
#     mps_device = torch.device("mps")
#     x = torch.ones(1, device=mps_device)
#     print (x)
# else:
#     print ("MPS device not found.")

In [48]:
class SpatialAttention(nn.Module):
    def __init__(self, input_channels, kernel_size = 5, padding = 1) -> None:
        super(SpatialAttention, self).__init__()

        self.conv = nn.Conv2d(input_channels, 1, kernel_size=kernel_size, padding=padding)

    def forward(self, x):
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)

        pool = torch.concat([avg_pool, max_pool], dim = 1)

        attention = torch.sigmoid(self.conv(pool))

        attention = F.interpolate(attention, size=x.shape[2:], mode="bilinear", align_corners=False)

        print("attention shape: ", x.shape, (attention*x).shape)

        return attention * x

In [49]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 8, 7, 2, 0)
        self.bn1 = nn.BatchNorm2d(8)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.attn1 = SpatialAttention(8)
        
        self.conv2 = nn.Conv2d(8, 16, 3, 2, 1)
        self.bn2 = nn.BatchNorm2d(16)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.attn2 = SpatialAttention(16)
        
        self.conv3 = nn.Conv2d(16, 32, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(32)
        self.attn3 = SpatialAttention(32)
        
        # Fully connected layer
        self.fc1 = nn.Linear(32, num_classes)
        self.dropout = nn.Dropout(0.5)  # 50% dropout for regularization
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.attn1(x)
        x = self.pool1(x)
        
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.attn2(x)
        x = self.pool2(x)
        
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.attn3(x)
        
        x = x.view(x.size(0), -1)  # Flatten for the FC layer
        
        x = self.dropout(x)
        x = self.fc1(x)
        
        return x

In [50]:
def load_data(csv_path, attr_to_predict):
    # Load data
    df = pd.read_csv(csv_path)
    
    # Extract relevant columns
    df = df[['id', 'Category', attr_to_predict]]
    l1 = len(df)
    
    # Drop rows with missing values in the target attribute
    df.dropna(subset=[attr_to_predict], inplace=True)
    print("Number of nan objects: ", l1-len(df))
    print(f"Number of valid data points: {len(df)}")

    # Create image paths
    df['image_path'] = df['id'].apply(lambda x: os.path.join(IMG_DIR, f"{str(x).zfill(6)}.jpg"))
    
    return df

In [51]:
def preprocess_data(df, attr_to_predict):
    # Encode labels
    le = LabelEncoder()
    df['label'] = le.fit_transform(df[attr_to_predict])
    
    return df, le

In [52]:
class ImageDataLoader(Dataset):
    def __init__(self, df, transform) -> None:
        super().__init__()
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image = cv2.imread(self.df.iloc[idx]["image_path"])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = self.transform(image)

        label = self.df.iloc[idx]["label"]

        return image, label


In [53]:
CSV_PATH = "../data/cat_wise_csv/Kurtis_data.csv"
ATTR_TO_PRED = "color"

In [54]:
df = load_data(CSV_PATH, ATTR_TO_PRED)
df, le = preprocess_data(df, ATTR_TO_PRED)

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

dataset = ImageDataLoader(df=df, transform=transform)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


Number of nan objects:  193
Number of valid data points: 6629


In [55]:
model = SimpleCNN(num_classes = len(le.classes_))

In [56]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
print(f"Device: {device}")

Device: mps


In [57]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [58]:
model.train()
for epoch in range(10):
    print(f"Epoch {epoch} running...")
    for images, labels in tqdm.tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Epoch 0 running...


  0%|          | 0/166 [00:01<?, ?it/s]


RuntimeError: Given groups=1, weight of size [1, 8, 5, 5], expected input[32, 2, 125, 125] to have 8 channels, but got 2 channels instead

In [None]:
print("Evaluating the model...")
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

# Print metrics
print("Classification Report:")
print(classification_report(all_labels, all_preds, target_names=le.classes_))
print(f"Accuracy: {accuracy_score(all_labels, all_preds) * 100:.2f}%")

Evaluating the model...
Classification Report:
              precision    recall  f1-score   support

       black       0.59      0.85      0.70       282
        blue       0.67      0.18      0.28        89
       green       0.36      0.39      0.38        46
        grey       0.00      0.00      0.00        32
      maroon       0.68      0.92      0.78       216
  multicolor       0.17      0.04      0.06        78
   navy blue       0.70      0.78      0.74       156
      orange       0.00      0.00      0.00         6
        pink       0.85      0.50      0.63        44
      purple       0.57      0.11      0.18        38
         red       0.82      0.92      0.87       250
       white       0.00      0.00      0.00        16
      yellow       0.75      0.52      0.61        73

    accuracy                           0.67      1326
   macro avg       0.47      0.40      0.40      1326
weighted avg       0.63      0.67      0.62      1326

Accuracy: 67.27%


In [108]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, accuracy_score
from PIL import Image
import tqdm


# Custom Dataset for loading images and multi-label attributes
class CustomImageDataset(Dataset):
    def __init__(self, dataframe, img_dir, attr_columns, transform=None):
        self.dataframe = dataframe
        self.img_dir = img_dir
        self.attr_columns = attr_columns
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        print(self.dataframe.columns)
        img_path = self.dataframe.iloc[idx]['image_path']
        image = Image.open(img_path).convert("RGB")
        labels = self.dataframe.iloc[idx][self.attr_columns].values.astype(int)
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(labels)

# Dynamic CNN Model with shared and attribute-specific heads
class MultiAttributeCNN(nn.Module):
    def __init__(self, attr_classes):
        super(MultiAttributeCNN, self).__init__()
        
        # Shared CNN layers
        self.shared_blocks = nn.Sequential(
            self._cnn_block(3, 32),
            self._cnn_block(32, 64),
        )
        
        # Attribute-specific heads
        self.attribute_heads = nn.ModuleList([
            nn.Sequential(
                self._cnn_block(64, 128),
                nn.Flatten(),
                nn.Linear(128 * 16 * 16, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, attr_classes[attr])
            ) for attr in attr_classes
        ])

    def _cnn_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        x = self.shared_blocks(x)
        outputs = [head(x) for head in self.attribute_heads]
        return outputs

# Function to load and preprocess data
def load_data(csv_path):
    df = pd.read_csv(csv_path)
    df['image_path'] = df['id'].apply(lambda x: os.path.join(IMG_DIR, f"{str(x).zfill(6)}.jpg"))
    return df

def preprocess_data(df, attr_to_predict):
    le = LabelEncoder()
    df['label'] = le.fit_transform(df[attr_to_predict])
    return df, le
# Training function for multi-attribute prediction
def train_cnn_on_attributes(csv_path, img_dir, attr_info_csv, category, epochs=10, batch_size=4):
    # Load attribute information
    attr_info_df = pd.read_csv(attr_info_csv)
    category_info = attr_info_df[attr_info_df['Category'] == category]
    attr_list = category_info['Attribute_list'].values[0]
    attr_list = [i.replace("[","").replace("]", "").replace("\'", "").replace("\n", "") for i in attr_list.split(" ")]

    # Load and prepare data
    df = load_data(csv_path)
    # df = df[['id', 'Category'] + attr_list].fillna("0")
    
    # Encoding labels for each attribute
    label_encoders = {attr: LabelEncoder().fit(df[attr]) for attr in attr_list}
    for attr in attr_list:
        df[attr] = label_encoders[attr].transform(df[attr])
    
    # Transformations
    transform = transforms.Compose([
        transforms.CenterCrop(64),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    
    dataset = CustomImageDataset(dataframe=df, img_dir=img_dir, attr_columns=attr_list, transform=transform)
    
    # Split dataset
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Define attribute classes
    attr_classes = {attr: len(le.classes_) for attr, le in label_encoders.items()}
    model = MultiAttributeCNN(attr_classes)
    
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)

    # Loss functions and optimizer
    criterion = {attr: nn.CrossEntropyLoss() for attr in attr_list}
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in tqdm.tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            
            # Loss calculation for each attribute
            losses = [criterion[attr](output, labels[:, i]) for i, (attr, output) in enumerate(zip(attr_list, outputs))]
            loss = sum(losses)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}")

    # Evaluation
    model.eval()
    all_preds = {attr: [] for attr in attr_list}
    all_labels = {attr: [] for attr in attr_list}

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            for i, (attr, output) in enumerate(zip(attr_list, outputs)):
                _, preds = torch.max(output, 1)
                all_preds[attr].extend(preds.cpu().numpy())
                all_labels[attr].extend(labels[:, i].cpu().numpy())

    # Print metrics for each attribute
    for attr in attr_list:
        print(f"Classification Report for {attr}:")
        print(classification_report(all_labels[attr], all_preds[attr], target_names=label_encoders[attr].classes_))
        print(f"Accuracy for {attr}: {accuracy_score(all_labels[attr], all_preds[attr]) * 100:.2f}%")

In [109]:
csv_path = "/Users/susanketsarkar/Desktop/Code/Meesho/data/cat_wise_csv/Kurtis_data.csv"
category_to_predict = "Kurtis"
img_dir = "../data/train_images"
attr_info_csv = "../data/attr_info.csv"

train_cnn_on_attributes(csv_path, img_dir, attr_info_csv, category_to_predict, epochs=10, batch_size=32)

['color' 'fit_shape' 'length' 'occasion' 'ornamentation' 'pattern'
 'print_or_pattern_type' 'sleeve_length' 'sleeve_styling']
['color', 'fit_shape', 'length', 'occasion', 'ornamentation', 'pattern', 'print_or_pattern_type', 'sleeve_length', 'sleeve_styling']


  0%|          | 0/171 [00:00<?, ?it/s]

Index(['id', 'Category', 'len', 'color', 'fit_shape', 'length', 'occasion',
       'ornamentation', 'pattern', 'print_or_pattern_type', 'sleeve_length',
       'sleeve_styling', 'image_path'],
      dtype='object')
Index(['id', 'Category', 'len', 'color', 'fit_shape', 'length', 'occasion',
       'ornamentation', 'pattern', 'print_or_pattern_type', 'sleeve_length',
       'sleeve_styling', 'image_path'],
      dtype='object')
Index(['id', 'Category', 'len', 'color', 'fit_shape', 'length', 'occasion',
       'ornamentation', 'pattern', 'print_or_pattern_type', 'sleeve_length',
       'sleeve_styling', 'image_path'],
      dtype='object')
Index(['id', 'Category', 'len', 'color', 'fit_shape', 'length', 'occasion',
       'ornamentation', 'pattern', 'print_or_pattern_type', 'sleeve_length',
       'sleeve_styling', 'image_path'],
      dtype='object')
Index(['id', 'Category', 'len', 'color', 'fit_shape', 'length', 'occasion',
       'ornamentation', 'pattern', 'print_or_pattern_type', 'sle

  0%|          | 0/171 [00:00<?, ?it/s]


RuntimeError: linear(): input and weight.T shapes cannot be multiplied (32x8192 and 32768x256)