In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import torchvision.models as models
import os
from PIL import Image
import wandb

# Define the CNN model
class ImageClassifier(nn.Module):
    """ResNet-18 backbone whose final fully-connected layer is resized to ``num_classes``.

    The backbone is built with ``models.resnet18(pretrained=True)``. Only its
    ``fc`` head is replaced; no layers are frozen, so the whole network trains.

    Args:
        num_classes: Number of output units of the new classification head.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Swap the pretrained classification head for one sized to this task,
        # keeping the same input width.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the backbone's raw (unnormalised) class scores for the batch ``x``."""
        return self.model(x)

# Custom dataset class
class CustomDataset(Dataset):
    """Image dataset laid out as ``root_dir/<class_name>/<image_file>``.

    Each immediate sub-directory of ``root_dir`` is one class. A sample's label is
    the index of its directory name in ``self.classes``. Class names and file
    names are sorted, so the class -> label mapping is deterministic. It is also
    identical for every split that contains the same class folders. Plain
    ``os.listdir`` order is arbitrary and can differ between directories and
    machines.

    Hidden entries (names starting with ``'.'``, e.g. the ``.ipynb_checkpoints``
    folders Jupyter creates) are ignored. So are non-directory files at the top
    level and non-file entries inside a class folder.

    Args:
        root_dir: Path to the split directory.
        transform: Optional callable applied to each RGB PIL image in ``__getitem__``.

    Attributes:
        classes: Sorted list of class directory names.
        images: Path of every sample.
        labels: Integer label of every sample, parallel to ``images``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        # Sorted + filtered so train/val splits agree on label indices.
        self.classes = sorted(
            name for name in os.listdir(root_dir)
            if not name.startswith('.') and os.path.isdir(os.path.join(root_dir, name))
        )

        for label, cls in enumerate(self.classes):
            cls_path = os.path.join(root_dir, cls)
            for img_name in sorted(os.listdir(cls_path)):
                img_path = os.path.join(cls_path, img_name)
                # Skip hidden files and nested folders: PIL cannot open them.
                if img_name.startswith('.') or not os.path.isfile(img_path):
                    continue
                self.images.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is RGB and transformed if set."""
        img_path = self.images[idx]
        label = self.labels[idx]
        # The context manager closes the file handle opened by Image.open.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=1):
    """Train ``model`` for ``num_epochs`` epochs, validating and logging after each one.

    After every training epoch the model is evaluated on ``val_loader``. The mean
    training loss, mean validation loss and validation accuracy are printed and
    logged with ``wandb.log``. Empty loaders produce NaN metrics instead of a
    ZeroDivisionError.

    Relies on a module-level ``device`` and an active wandb run (``wandb.init``).

    Args:
        model: Classifier producing one row of class scores per sample (assumed;
            predictions are the argmax over dim 1).
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss taking ``(outputs, labels)``, e.g. ``nn.CrossEntropyLoss``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    wandb.watch(model, criterion)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        train_seen = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is exact even when the last batch is smaller.
            running_loss += loss.item() * images.size(0)
            train_seen += images.size(0)
        epoch_loss = running_loss / train_seen if train_seen else float('nan')

        # Validation pass: eval mode and no autograd graph.
        model.eval()
        val_loss_sum = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_loss_sum += criterion(outputs, labels).item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        val_loss = val_loss_sum / total if total else float('nan')
        val_acc = correct / total if total else float('nan')

        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc:.4f}")
        # Log metrics to wandb
        wandb.log({"Training Loss": epoch_loss, "Validation Loss": val_loss, "Validation Accuracy": val_acc})

import os

# Function to save a model/optimizer checkpoint to a local directory (upload it to wandb separately, e.g. with wandb.save)
def save_checkpoint(model, epoch, optimizer, path='./classifier_checkpoints'):
    """Write a training checkpoint to ``<path>/model_epoch_<epoch>.pt`` on local disk.

    The file holds a dict with the keys ``'epoch'``, ``'model_state_dict'`` and
    ``'optimizer_state_dict'``. ``path`` is created if it does not exist. A
    previous checkpoint for the same epoch is overwritten.

    Returns:
        The path of the file that was written.
    """
    target = os.path.join(path, f"model_epoch_{epoch}.pt")
    os.makedirs(path, exist_ok=True)
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(payload, target)
    return target

# Function to load model from wandb
def load_model_from_wandb(run_path, model, alias="latest", filename="model.pt"):
    """Download a wandb artifact and load the weights it contains into ``model``.

    Args:
        run_path: Artifact name given to ``use_artifact`` without the alias
            (presumably ``"entity/project/artifact_name"`` — TODO confirm).
        model: Module whose parameters are overwritten in place.
        alias: Alias or version appended as ``"<run_path>:<alias>"``.
        filename: File inside the downloaded artifact directory holding the weights.

    Returns:
        The same ``model`` object, with the loaded weights.

    The weights file may contain either a bare ``state_dict`` or a checkpoint dict
    with a ``'model_state_dict'`` entry, which is the format ``save_checkpoint``
    in this notebook writes.
    """
    # Reuse the active run when there is one. Calling wandb.init() again may
    # finish the current run and start a new one (depends on wandb reinit settings).
    run = wandb.run if wandb.run is not None else wandb.init()
    artifact = run.use_artifact(f"{run_path}:{alias}")
    artifact_dir = artifact.download()
    model_path = os.path.join(artifact_dir, filename)
    # Deserialize onto the CPU so GPU-saved weights also load on CPU-only hosts;
    # load_state_dict copies the values into the model's existing tensors.
    state = torch.load(model_path, map_location="cpu")
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    model.load_state_dict(state)
    return model

# Training function with visualization, wandb logging, and model checkpoint saving
def train_model_wandb_checkpoint(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` with per-epoch validation, wandb logging and checkpointing.

    Each batch: forward/backward/step, then the batch loss is logged to wandb
    under ``"Batch Training Loss"``.

    Each epoch:
      - the model is evaluated on ``val_loader``;
      - the mean training loss, mean validation loss and validation accuracy are
        printed and logged, together with the 1-based ``"Epoch"``;
      - a checkpoint is written with ``save_checkpoint`` and passed to ``wandb.save``.

    Empty loaders produce NaN metrics for that epoch instead of a ZeroDivisionError.

    Relies on the module-level ``device``, ``save_checkpoint`` and an active wandb run.

    Args:
        model: Classifier producing one row of class scores per sample (assumed;
            predictions are the argmax over dim 1).
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss taking ``(outputs, labels)``, e.g. ``nn.CrossEntropyLoss``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    # wandb.watch logs gradient histograms by default (pass log="all" to also log parameters).
    wandb.watch(model, criterion)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        train_seen = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() copies the scalar to the host; do it once per batch.
            batch_loss = loss.item()
            running_loss += batch_loss * images.size(0)
            train_seen += images.size(0)
            # Use a key distinct from the epoch-level "Training Loss" so per-batch
            # and per-epoch values are not interleaved in one wandb series.
            wandb.log({"Batch Training Loss": batch_loss})
        epoch_loss = running_loss / train_seen if train_seen else float("nan")

        # Validation: eval mode and no autograd graph.
        model.eval()
        val_loss_sum = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_loss_sum += criterion(outputs, labels).item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        val_loss = val_loss_sum / total if total else float("nan")
        val_acc = correct / total if total else float("nan")

        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc:.4f}")

        # Log epoch-level metrics to wandb
        wandb.log({
            "Epoch": epoch + 1,
            "Training Loss": epoch_loss,
            "Validation Loss": val_loss,
            "Validation Accuracy": val_acc,
        })

        # Write the checkpoint locally, then upload it with the run's files.
        checkpoint_path = save_checkpoint(model, epoch, optimizer)
        wandb.save(checkpoint_path)


# Inference function
def infer_single_image(image_path, model, transform, device=None):
    """Predict the class index of a single image file.

    Args:
        image_path: Path to an image readable by PIL; it is converted to RGB.
        model: Classifier producing one row of class scores per input (assumed;
            the argmax over dim 1 is taken).
        transform: Callable mapping a PIL image to a single-image tensor
            (assumed ``(C, H, W)``; a batch dimension is added here).
        device: Device to run inference on. Defaults to the device of the
            model's first parameter (CPU if it has none), so input and model
            always agree.

    Returns:
        The predicted class index as a Python ``int``.

    Note: switches ``model`` to eval mode and leaves it there.
    """
    # Context manager closes the underlying file once the RGB copy is made.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    batch = transform(image).unsqueeze(0)  # Add batch dimension
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    with torch.no_grad():
        output = model(batch.to(device))
        _, predicted = torch.max(output, 1)
    return predicted.item()

# Compute device: the default CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Preprocessing shared by training, validation and inference. The normalisation
# uses the standard ImageNet mean/std commonly paired with torchvision's
# pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # exact 224x224; aspect ratio is not preserved
        transforms.ToTensor(),          # PIL image -> float tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyper-parameters, recorded in the wandb run config below
SEED = 42
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 1

# Seed torch's RNGs (new head's weight init, DataLoader shuffling) for more reproducible runs
torch.manual_seed(SEED)

# Initialise wandb
wandb.init(
    project="day-night-classifier",
    config={"seed": SEED, "batch_size": BATCH_SIZE, "learning_rate": LEARNING_RATE, "epochs": NUM_EPOCHS},
)

# Create datasets and dataloaders
train_dataset = CustomDataset(root_dir='datasets/bdd100k/train', transform=transform)
val_dataset = CustomDataset(root_dir='datasets/bdd100k/test', transform=transform)
# Labels are indices into `.classes`; if the splits disagree, validation metrics are meaningless.
assert train_dataset.classes == val_dataset.classes, (
    f"class mismatch: train={train_dataset.classes} val={val_dataset.classes}"
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize model, loss function, and optimizer
num_classes = len(train_dataset.classes)
model = ImageClassifier(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train the model with validation, wandb logging and per-epoch checkpoints
train_model_wandb_checkpoint(model, train_loader, val_loader, criterion, optimizer, num_epochs=NUM_EPOCHS)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mtharushalekamge-19[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/tharusha/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 210MB/s]


KeyboardInterrupt: 

In [None]:
# Inference example for a single image.
# Defaults to the first validation image so the cell runs as-is under Run All;
# set `image_path` to any other image file to classify it instead.
image_path = val_dataset.images[0]
predicted_class = infer_single_image(image_path, model, transform)
print(f"Predicted class: {train_dataset.classes[predicted_class]}")