In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing pipeline (match what ResNet expects): fixed 224x224 resize,
# tensor conversion, then per-channel normalization with the ImageNet statistics
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Build the dataset with ImageFolder, applying the pipeline above to every image
root_dir = '/content/drive/MyDrive/dataset/train'
dataset = ImageFolder(root=root_dir, transform=data_transforms)

# Two-stage stratified split with a fixed seed:
#   1) hold out 20% of all samples as the test set,
#   2) hold out 20% of the remainder as the validation set.
targets = dataset.targets
all_indices = list(range(len(dataset)))
train_val_indices, test_indices = train_test_split(
    all_indices, test_size=0.2, stratify=targets, random_state=42)

# Labels of the remaining samples, needed to stratify the second split
train_val_targets = [targets[idx] for idx in train_val_indices]
train_indices, val_indices = train_test_split(
    train_val_indices, test_size=0.2, stratify=train_val_targets, random_state=42)

# Wrap each index list in a Subset and a DataLoader; only the training loader
# shuffles, validation and test keep a fixed order
train_subset = Subset(dataset, train_indices)
val_subset = Subset(dataset, val_indices)
test_subset = Subset(dataset, test_indices)

train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

# Report how many samples ended up in each split
print(f"Training set size: {len(train_indices)}")
print(f"Validation set size: {len(val_indices)}")
print(f"Testing set size: {len(test_indices)}")


Training set size: 2700
Validation set size: 676
Testing set size: 844


In [None]:
# Instantiate ResNet-50 with pretrained weights.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` — confirm the installed version before switching.
resnet = models.resnet50(pretrained=True)

# Freeze the whole network so no existing parameter receives gradients
# (optional: skip this step if you want to fine-tune the backbone)
resnet.requires_grad_(False)

# Swap the final fully connected layer for a fresh head; the new layer's
# parameters are trainable because it is created after the freeze
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(in_features=num_ftrs, out_features=4)  # 4 output classes for your problem

# Place the model on the selected device
model = resnet.to(device)

# Cross-entropy loss; Adam is handed only the new head's parameters,
# so the final layer is the only part that gets optimized
criterion = nn.CrossEntropyLoss()
head_params = model.fc.parameters()
optimizer = optim.Adam(head_params, lr=1e-3)




In [None]:
from tqdm import tqdm  # for progress bars
import torch

# Train the classifier head for `num_epochs` epochs, validating after each one
# and writing a resumable checkpoint to Drive at the end of every epoch.

# Number of epochs
num_epochs = 10
start_epoch = 0  # will update if resuming

checkpoint_path = "/content/drive/MyDrive/resnet50_checkpoint.pth"

# To resume from checkpoint if exists
if os.path.exists(checkpoint_path):
    print("Loading checkpoint...")
    # map_location remaps the saved tensors onto the current device, so a
    # checkpoint written on a GPU runtime can be resumed on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # 'epoch' is stored as (epoch index + 1), i.e. the number of completed
    # epochs, so it is directly the 0-based index of the next epoch to run.
    start_epoch = checkpoint['epoch']
    print(f"Resuming from epoch {start_epoch}")

for epoch in range(start_epoch, num_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0  # sum of per-batch mean losses
    correct = 0         # correctly classified training samples so far
    total = 0           # training samples seen so far

    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training", leave=False)
    for batch_idx, (images, labels) in enumerate(train_bar, start=1):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        train_acc = 100 * correct / total
        # running_loss accumulates batch-mean losses, so the running average is
        # taken over batches seen so far (not samples), matching avg_loss below.
        train_bar.set_postfix(loss=running_loss / batch_idx, accuracy=f"{train_acc:.2f}%")

    avg_loss = running_loss / len(train_loader)
    train_accuracy = 100 * correct / total

    # ---- Validation pass (no gradient tracking) ----
    model.eval()
    val_correct = 0
    val_total = 0
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation", leave=False)
    with torch.no_grad():
        for images, labels in val_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

            val_acc = 100 * val_correct / val_total
            val_bar.set_postfix(accuracy=f"{val_acc:.2f}%")

    val_accuracy = 100 * val_correct / val_total

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, "
          f"Train Acc: {train_accuracy:.2f}%, Val Acc: {val_accuracy:.2f}%")

    # Save checkpoint after each epoch (overwrites the previous one)
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch+1}")


Loading checkpoint...
Resuming from epoch 6




Epoch [7/10], Loss: 0.4016, Train Acc: 84.37%, Val Acc: 85.21%
Checkpoint saved at epoch 7




Epoch [8/10], Loss: 0.3956, Train Acc: 85.48%, Val Acc: 86.24%
Checkpoint saved at epoch 8




Epoch [9/10], Loss: 0.3883, Train Acc: 85.33%, Val Acc: 85.80%
Checkpoint saved at epoch 9




Epoch [10/10], Loss: 0.3745, Train Acc: 86.04%, Val Acc: 86.39%
Checkpoint saved at epoch 10


In [None]:
# Persist only the trained parameters (the model's state_dict) to Drive
final_model_path = "/content/drive/MyDrive/resnet50_final.pth"
final_state = model.state_dict()
torch.save(final_state, final_model_path)
print(f"Final model saved to {final_model_path}")


Final model saved to /content/drive/MyDrive/resnet50_final.pth


In [None]:
# Serialize the whole model object (not just its state_dict) to Drive
full_model_path = "/content/drive/MyDrive/resnet50_full_model.pth"
torch.save(obj=model, f=full_model_path)
print(f"Entire model (with architecture) saved to {full_model_path}")

Entire model (with architecture) saved to /content/drive/MyDrive/resnet50_full_model.pth


In [None]:
# Measure top-1 accuracy on the held-out test split
model.eval()
test_correct, test_total = 0, 0

# Inference only: gradient tracking is disabled for the whole pass
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)
        # Predicted class = index of the largest logit along the class dimension
        predicted_classes = logits.max(dim=1).indices

        test_total += batch_labels.shape[0]
        test_correct += (predicted_classes == batch_labels).sum().item()

test_accuracy = 100 * test_correct / test_total
print(f"Test Accuracy: {test_accuracy:.2f}%")

Test Accuracy: 85.31%
