In [None]:
import torch
import wandb
import numpy as np
import matplotlib.pyplot as plt
from torch import nn, optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

# Choose the compute device once, up front: a CUDA GPU when PyTorch can see one,
# otherwise the CPU. Later cells move the model and each batch onto `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
DATA_DIR = "inaturalist_12K"
BATCH_SIZE = 64
IMG_SIZE = 224
NUM_WORKERS = 4
VAL_FRACTION = 0.2
SEED = 42  # makes the train/val split and the torch RNG (augmentation, head init) reproducible

# ImageNet channel statistics: the backbone is loaded with ImageNet-pretrained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

torch.manual_seed(SEED)

# Define transformations: random augmentation for training only, deterministic for validation.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# Load the same directory twice so each split gets its own transform. Wrapping one dataset in
# both Subsets would silently apply the training augmentation to validation images too.
full_dataset = datasets.ImageFolder(f"{DATA_DIR}/train", train_transform)
full_dataset_eval = datasets.ImageFolder(f"{DATA_DIR}/train", val_transform)
# The index-based split below is only valid if both views enumerate identical samples.
assert full_dataset.samples == full_dataset_eval.samples

# Stratified split preserves class proportions in train and validation.
train_idx, val_idx = train_test_split(
    range(len(full_dataset)),
    test_size=VAL_FRACTION,
    stratify=full_dataset.targets,
    random_state=SEED
)

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
PIN_MEMORY = torch.cuda.is_available()

# Create dataloaders
train_loader = DataLoader(
    Subset(full_dataset, train_idx),
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS
)

val_loader = DataLoader(
    Subset(full_dataset_eval, val_idx),
    batch_size=BATCH_SIZE,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS
)


In [None]:
def create_resnet_model(num_classes=10, hidden_dim=512, dropout=0.5):
    """Build an ImageNet-pretrained ResNet-50 with a frozen backbone and a new trainable head.

    Args:
        num_classes: number of output logits. Defaults to 10.
        hidden_dim: width of the hidden layer in the new head. Defaults to 512.
        dropout: dropout probability before the output layer. Defaults to 0.5.

    Returns:
        The model, moved to the global ``device``.
    """
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

    # Freeze every pretrained parameter; only the head created below receives gradients.
    model.requires_grad_(False)

    # Replace the pretrained classifier. Newly constructed layers default to
    # requires_grad=True, so the head is trainable even though it is created after freezing.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, num_classes)
    )
    return model.to(device)


LEARNING_RATE = 1e-4

# One output per class folder discovered by ImageFolder instead of a hard-coded count.
model = create_resnet_model(num_classes=len(full_dataset.classes))
criterion = nn.CrossEntropyLoss()
# Only the head is trainable, so only its parameters are handed to the optimizer.
optimizer = optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)

In [None]:
# Authenticate via the WANDB_API_KEY environment variable or the interactive prompt.
# Never paste the key into the notebook: it is persisted in the .ipynb and its history.
wandb.login()

# Training parameters
EPOCHS = 10

wandb.init(project="DL-Assignment2-ResNet", config={
    "batch_size": BATCH_SIZE,
    # Read from the optimizer so the logged value cannot drift from the one actually used.
    "learning_rate": optimizer.param_groups[0]["lr"],
    "architecture": "ResNet50",
    "epochs": EPOCHS
})

for epoch in range(EPOCHS):
    # Training phase.
    # The backbone is frozen, but in train mode its BatchNorm layers would still update their
    # running statistics from every batch. Keep the backbone in eval mode and switch only the
    # new head (which contains the Dropout layer) to train mode.
    model.eval()
    model.fc.train()
    train_loss_sum = 0.0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # one device-to-host sync per step
        # criterion returns the batch mean; weight by batch size for an exact per-sample epoch mean.
        train_loss_sum += batch_loss * labels.size(0)
        wandb.log({"train_loss": batch_loss})

    # Validation phase
    model.eval()
    val_loss_sum, correct = 0.0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            val_loss_sum += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    # Log per-sample epoch averages (the final batch may be smaller than BATCH_SIZE).
    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)
    val_acc = 100 * correct / n_val
    wandb.log({
        "epoch": epoch+1,
        "train_loss_epoch": train_loss_sum / n_train,
        "val_loss": val_loss_sum / n_val,
        "val_acc": val_acc
    })


In [None]:
def plot_predictions(model, dataloader, class_names, per_class=3):
    """Plot a grid of example images per true class with the model's predictions and log it to W&B.

    Rows are classes (in label-index order), columns are examples.

    Args:
        model: classifier producing one logit per entry of ``class_names``.
        dataloader: yields ``(images, labels)`` batches; images are assumed to be normalized
            with the ImageNet mean/std used in this notebook's transforms.
        class_names: display names indexed by label id; its length sets the number of rows.
        per_class: number of examples shown per class. Defaults to 3.
    """
    num_classes = len(class_names)
    # Run on whatever device the model lives on instead of relying on a global.
    model_device = next(model.parameters()).device
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    model.eval()
    samples = {i: [] for i in range(num_classes)}

    with torch.no_grad():
        for images, labels in dataloader:
            preds = model(images.to(model_device)).argmax(dim=1).cpu()
            for img, label, pred in zip(images, labels.tolist(), preds.tolist()):
                if len(samples[label]) < per_class:
                    samples[label].append((img.cpu(), pred))
            # Stop as soon as every class has enough examples instead of scanning the whole set.
            if all(len(examples) >= per_class for examples in samples.values()):
                break

    # Plotting; squeeze=False keeps `axes` 2-D even for a single class or column.
    fig, axes = plt.subplots(
        num_classes, per_class, figsize=(4 * per_class, 3 * num_classes), squeeze=False
    )
    for cls_idx in range(num_classes):
        examples = samples[cls_idx]
        for ex_idx in range(per_class):
            ax = axes[cls_idx, ex_idx]
            ax.axis('off')
            if ex_idx >= len(examples):
                continue  # this class had fewer than `per_class` examples
            img, pred = examples[ex_idx]
            # Undo Normalize so colors render correctly, then clamp to imshow's valid range.
            img = img.permute(1, 2, 0).numpy() * std + mean
            ax.imshow(np.clip(img, 0, 1))
            ax.set_title(f"True: {class_names[cls_idx]}\nPred: {class_names[pred]}")

    fig.tight_layout()
    # Log the rendered figure as an image; logging the pyplot module goes through a
    # matplotlib-to-plotly conversion that does not handle imshow content well.
    wandb.log({"predictions": wandb.Image(fig)})

# Class names in label-index order. ImageFolder derives label indices from its folder names,
# so read them from the dataset behind val_loader instead of keeping a hand-written copy.
class_names = val_loader.dataset.dataset.classes

# Generate and log predictions
plot_predictions(model, val_loader, class_names)