In [1]:
print("File started")

File started


In [2]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_api_key")

!pip3 install wandb
import wandb
wandb.login(key=secret_value_0)



[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpratham3992[0m ([33mpratham3992-plaksha[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
print("Wandb login done")

Wandb login done


In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from PIL import Image
from tqdm import tqdm
import time
import copy
import wandb  # Import wandb

device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
print(f"Using device: {device}")

class ProductDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None, balanced_subset=False, max_per_class=5):
        self.df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        
        if balanced_subset:
            self.df = self._create_balanced_subset(max_per_class)
        
        self.num_classes = len(self.df['class'].unique())
        
    def _create_balanced_subset(self, max_per_class):
        balanced_df = pd.DataFrame()
        
        for class_id in self.df['class'].unique():
            class_df = self.df[self.df['class'] == class_id]
            if len(class_df) > max_per_class:
                sampled_df = class_df.sample(max_per_class, random_state=42)
                balanced_df = pd.concat([balanced_df, sampled_df])
            else:
                balanced_df = pd.concat([balanced_df, class_df])
        
        return balanced_df
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        img_name = os.path.join(self.img_dir, self.df.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        class_id = self.df.iloc[idx, 1]
        group_id = self.df.iloc[idx, 2] if 'group' in self.df.columns else -1
        
        sample = {'image': image, 'class': class_id, 'group': group_id, 'filename': self.df.iloc[idx, 0]}
        
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        
        return sample

def get_data_transforms(high_res=False):
    resize_size = 512 if high_res else 256
    crop_size = 448 if high_res else 224
    
    train_transform = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    return train_transform, val_transform

class TripletLoss(nn.Module):
    def __init__(self, margin=0.3):
        super(TripletLoss, self).__init__()
        self.margin = margin
    
    def forward(self, anchor, positive, negative):
        pos_dist = torch.sum((anchor - positive) ** 2, dim=1)
        neg_dist = torch.sum((anchor - negative) ** 2, dim=1)
        loss = torch.mean(torch.clamp(pos_dist - neg_dist + self.margin, min=0))
        return loss

class FeatureExtractor(nn.Module):
    def __init__(self, num_classes, embedding_dim=1536, pretrained=True):
        super(FeatureExtractor, self).__init__()
        
        model = models.efficientnet_b1(weights = 'DEFAULT' if pretrained else None)
        
        self.features = nn.Sequential(*list(model.children())[:-1])
        
        in_features = model.classifier[1].in_features
        self.embedding = nn.Linear(in_features, embedding_dim)
        self.fc = nn.Linear(embedding_dim, num_classes)
        
    def forward(self, x):
        features = self.features(x)
        features = features.flatten(start_dim=1)
        
        embeddings = self.embedding(features)
        classifier = self.fc(embeddings)
        
        return embeddings, classifier
    
    def extract_features(self, x):
        features = self.features(x)
        features = features.flatten(start_dim=1)
        embeddings = self.embedding(features)
        return embeddings

class AccuracyLoss(nn.Module):
    def __init__(self):
        super(AccuracyLoss, self).__init__()
    
    def forward(self, outputs, targets):
        if targets.dim() == 1:
            targets = torch.nn.functional.one_hot(targets, num_classes=outputs.size(1)).float()
        
        correct_probs = (outputs * targets).sum(dim=1)
        loss = 1.0 - torch.mean(correct_probs)
        return loss

def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25, use_metric_loss=False):
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            
            pbar = tqdm(dataloaders[phase], desc=phase)
            for batch in pbar:
                inputs = batch['image'].to(device)
                labels = batch['class'].to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    embeddings, outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    
                    if not use_metric_loss:
                        loss = criterion(outputs, labels)
                    else:
                        probs = nn.functional.softmax(outputs, dim=1)
                        loss = criterion(probs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                # Update tqdm with current loss
                pbar.set_postfix(loss=f"{loss.item():.4f}")
                
                # Log batch-level metrics to wandb
                wandb.log({f"{phase}_batch_loss": loss.item()})
            
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            # Log epoch-level metrics to wandb
            wandb.log({
                f"{phase}_loss": epoch_loss,
                f"{phase}_accuracy": epoch_acc,
                "epoch": epoch
            })
            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                
                # Log best model checkpoint to wandb
                torch.save(model.state_dict(), "best_model.pth")
                wandb.save("best_model.pth")
        
        if scheduler is not None:
            scheduler.step()
            # Log learning rate
            wandb.log({"learning_rate": optimizer.param_groups[0]['lr'], "epoch": epoch})
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    
    # Log final best accuracy
    wandb.log({"best_val_accuracy": best_acc})
    
    model.load_state_dict(best_model_wts)
    return model

def finetune_model(model, balanced_dataloader, criterion, optimizer, num_epochs=3, use_metric_loss=False):
    print("Starting finetuning on balanced subset...")
    since = time.time()
    
    model.train()
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        running_loss = 0.0
        running_corrects = 0
        
        pbar = tqdm(balanced_dataloader, desc="Finetuning")
        for batch in pbar:
            inputs = batch['image'].to(device)
            labels = batch['class'].to(device)
            
            optimizer.zero_grad()
            
            embeddings, outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            if not use_metric_loss:
                loss = criterion(outputs, labels)
            else:
                probs = nn.functional.softmax(outputs, dim=1)
                loss = criterion(probs, labels)
            
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            
            # Update tqdm with current loss
            pbar.set_postfix(loss=f"{loss.item():.4f}")
            
            # Log batch-level metrics to wandb
            wandb.log({"finetune_batch_loss": loss.item()})
        
        epoch_loss = running_loss / len(balanced_dataloader.dataset)
        epoch_acc = running_corrects.double() / len(balanced_dataloader.dataset)
        
        print(f'Finetune Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
        # Log epoch-level metrics to wandb
        wandb.log({
            "finetune_loss": epoch_loss,
            "finetune_accuracy": epoch_acc,
            "finetune_epoch": epoch
        })
    
    time_elapsed = time.time() - since
    print(f'Finetuning complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    
    return model

def main():
    # Set parameters directly here instead of using argparse
    train_dir = '/kaggle/input/visual-product-recognition/train/train'
    train_csv = '/kaggle/input/visual-product-recognition/train.csv'
    batch_size = 16
    epochs = 10
    high_res = False
    balanced_finetune = True
    metric_loss = False
    lr = 3e-4
    finetune_lr = 3e-5
    embedding_dim = 1280
    weight_decay = 1e-4
    
    # Initialize wandb
    wandb.init(
        project="visual-product-recognition",
        config={
            "model": "efficientnet_b1",
            "epochs": epochs,
            "batch_size": batch_size,
            "learning_rate": lr,
            "finetune_learning_rate": finetune_lr,
            "embedding_dim": embedding_dim,
            "weight_decay": weight_decay,
            "high_res": high_res,
            "balanced_finetune": balanced_finetune,
            "metric_loss": metric_loss
        }
    )
    
    # Get data transformations
    train_transform, val_transform = get_data_transforms(high_res=high_res)
    
    # Load datasets
    train_dataset = ProductDataset(train_csv, train_dir, transform=train_transform)
    
    # Create a validation split from the training data
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_subset, val_subset = random_split(train_dataset, [train_size, val_size])
    
    # Update validation subset transform
    val_subset.dataset.transform = val_transform
    
    # Create data loaders for training and validation
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=4)
    
    dataloaders = {'train': train_loader, 'val': val_loader}
    
    # Create balanced dataset for finetuning if specified
    if balanced_finetune:
        balanced_dataset = ProductDataset(
            train_csv, train_dir, transform=train_transform, 
            balanced_subset=True, max_per_class=5
        )
        balanced_loader = DataLoader(balanced_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    
    # Get the number of classes from the dataset
    num_classes = train_dataset.num_classes
    print(f"Number of classes: {num_classes}")
    wandb.config.update({"num_classes": num_classes})
    
    # Initialize the model
    model = FeatureExtractor(num_classes=num_classes, embedding_dim=embedding_dim)
    model = model.to(device)
    
    # Log model architecture to wandb
    wandb.watch(model)
    
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    
    # If using metric-guided loss, use our custom AccuracyLoss
    if metric_loss:
        metric_criterion = AccuracyLoss()
    
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40, 60], gamma=0.1)
    
    # Train the model
    print("Starting model training...")
    model = train_model(
        model, dataloaders, criterion, optimizer, scheduler,
        num_epochs=epochs, use_metric_loss=False
    )
    
    # Finetune on balanced subset if specified
    if balanced_finetune:
        print("Finetuning on balanced subset...")
        finetune_optimizer = optim.Adam(model.parameters(), lr=finetune_lr, weight_decay=weight_decay)
        
        if metric_loss:
            model = finetune_model(
                model, balanced_loader, metric_criterion, finetune_optimizer, 
                num_epochs=1, use_metric_loss=True
            )
        else:
            model = finetune_model(
                model, balanced_loader, criterion, finetune_optimizer, 
                num_epochs=3, use_metric_loss=False
            )
    
    # Save the trained model
    torch.save(model.state_dict(), 'product_model.pth')
    wandb.save('product_model.pth')
    
    # Finish wandb run
    wandb.finish()
    
    print("Done!")

if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Using device: cuda


[34m[1mwandb[0m: Tracking run with wandb version 0.19.6
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250510_112630-kwb6ivst[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mpolar-totem-3[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/pratham3992-plaksha/visual-product-recognition[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/pratham3992-plaksha/visual-product-recognition/runs/kwb6ivst[0m


Number of classes: 9691


Downloading: "https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b1-c27df63c.pth
100%|██████████| 30.1M/30.1M [00:00<00:00, 197MB/s]


Starting model training...
Epoch 1/10
----------


train: 100%|██████████| 7097/7097 [15:20<00:00,  7.71it/s, loss=2.6591]


train Loss: 5.1097 Acc: 0.2300


val: 100%|██████████| 1775/1775 [03:19<00:00,  8.90it/s, loss=0.9132]


val Loss: 3.0046 Acc: 0.4365

Epoch 2/10
----------


train: 100%|██████████| 7097/7097 [14:08<00:00,  8.37it/s, loss=3.6372]


train Loss: 2.4062 Acc: 0.5180


val: 100%|██████████| 1775/1775 [03:05<00:00,  9.55it/s, loss=0.2114]


val Loss: 2.4139 Acc: 0.5347

Epoch 3/10
----------


train: 100%|██████████| 7097/7097 [14:12<00:00,  8.33it/s, loss=3.4726]


train Loss: 1.7861 Acc: 0.6178


val: 100%|██████████| 1775/1775 [03:09<00:00,  9.38it/s, loss=0.4449]


val Loss: 2.3192 Acc: 0.5563

Epoch 4/10
----------


train: 100%|██████████| 7097/7097 [14:25<00:00,  8.20it/s, loss=1.8240]


train Loss: 1.4998 Acc: 0.6686


val: 100%|██████████| 1775/1775 [03:09<00:00,  9.37it/s, loss=0.5384]


val Loss: 2.3283 Acc: 0.5652

Epoch 5/10
----------


train: 100%|██████████| 7097/7097 [14:31<00:00,  8.14it/s, loss=3.3297]


train Loss: 1.3318 Acc: 0.7005


val: 100%|██████████| 1775/1775 [03:04<00:00,  9.64it/s, loss=0.7799]


val Loss: 2.3542 Acc: 0.5723

Epoch 6/10
----------


train: 100%|██████████| 7097/7097 [14:18<00:00,  8.26it/s, loss=1.9090]


train Loss: 1.2078 Acc: 0.7234


val: 100%|██████████| 1775/1775 [03:10<00:00,  9.31it/s, loss=0.0168]


val Loss: 2.3702 Acc: 0.5760

Epoch 7/10
----------


train: 100%|██████████| 7097/7097 [14:28<00:00,  8.17it/s, loss=2.1659]


train Loss: 1.1255 Acc: 0.7405


val: 100%|██████████| 1775/1775 [03:07<00:00,  9.46it/s, loss=0.1685]


val Loss: 2.4364 Acc: 0.5653

Epoch 8/10
----------


train: 100%|██████████| 7097/7097 [14:16<00:00,  8.29it/s, loss=2.1762]


train Loss: 1.0617 Acc: 0.7528


val: 100%|██████████| 1775/1775 [03:04<00:00,  9.64it/s, loss=0.9039]


val Loss: 2.4607 Acc: 0.5734

Epoch 9/10
----------


train: 100%|██████████| 7097/7097 [14:16<00:00,  8.29it/s, loss=0.7482]


train Loss: 0.9995 Acc: 0.7659


val: 100%|██████████| 1775/1775 [03:04<00:00,  9.60it/s, loss=3.4225]


val Loss: 2.4957 Acc: 0.5694

Epoch 10/10
----------


train: 100%|██████████| 7097/7097 [14:28<00:00,  8.17it/s, loss=2.9500]


train Loss: 0.9635 Acc: 0.7742


val: 100%|██████████| 1775/1775 [03:01<00:00,  9.78it/s, loss=0.0310]


val Loss: 2.5110 Acc: 0.5718

Training complete in 175m 47s
Best val Acc: 0.5760
Finetuning on balanced subset...
Starting finetuning on balanced subset...
Epoch 1/3
----------


Finetuning: 100%|██████████| 2863/2863 [07:12<00:00,  6.62it/s, loss=0.8624]


Finetune Loss: 1.8381 Acc: 0.6523
Epoch 2/3
----------


Finetuning: 100%|██████████| 2863/2863 [07:05<00:00,  6.73it/s, loss=4.7554]


Finetune Loss: 1.0469 Acc: 0.7925
Epoch 3/3
----------


Finetuning: 100%|██████████| 2863/2863 [07:03<00:00,  6.75it/s, loss=2.2463]


Finetune Loss: 0.7329 Acc: 0.8496
Finetuning complete in 21m 22s


[34m[1mwandb[0m: uploading product_model.pth; uploading output.log; uploading wandb-summary.json; uploading config.yaml; uploading best_model.pth
[34m[1mwandb[0m: uploading product_model.pth; uploading wandb-summary.json; uploading best_model.pth
[34m[1mwandb[0m: uploading product_model.pth; uploading best_model.pth
[34m[1mwandb[0m: uploading product_model.pth
[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:   best_val_accuracy ▁
[34m[1mwandb[0m:               epoch ▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇███
[34m[1mwandb[0m:   finetune_accuracy ▁▆█
[34m[1mwandb[0m: finetune_batch_loss ▄▅▇█▃▅▂▃▃▄▄▅▃▂▂▃▂▃▂▃▃▃▂▃▃▃▄▄▂▂▁▁▃▂▁▃▂▃▂▁
[34m[1mwandb[0m:      finetune_epoch ▁▅█
[34m[1mwandb[0m:       finetune_loss █▃▁
[34m[1mwandb[0m:       learning_rate ▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:      train_accuracy ▁▅▆▇▇▇████
[34m[1mwandb[0m:    train_batch_loss █▇▆

Done!
