In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.



In [1]:
# Define the paths for your dataset
train_dir = "data/train"
valid_dir = "data/validation"

In [3]:
import os
import timm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Define the transformations for the train and validation data
train_transforms = transforms.Compose([
    transforms.Resize((192, 192)),
    transforms.ColorJitter(brightness=0.125, contrast=0.125, saturation=0.05, hue=0.025),
    transforms.RandomAffine(degrees=10, translate=(0.075, 0.075), shear=0.025),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

valid_transforms = transforms.Compose([
    transforms.Resize((192, 192)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Load datasets using ImageFolder
train_data = ImageFolder(root=train_dir, transform=train_transforms)
valid_data = ImageFolder(root=valid_dir, transform=valid_transforms)

# DataLoader for train and validation
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=32, shuffle=False)

In [6]:
# Load the ResNet18 model from timm
model = timm.create_model('resnet18d', pretrained=True, num_classes=len(train_data.classes))

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3.25e-4, weight_decay=1e-3)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [13]:
# Training loop
def train_model(model, train_loader, valid_loader, criterion, optimizer, epochs=16):
    model.train()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}, Accuracy: {accuracy}%')

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_accuracy = 100 * val_correct / val_total
        print(f'Validation Loss: {val_loss / len(valid_loader)}, Validation Accuracy: {val_accuracy}%')
    torch.save(model.state_dict(), "models/model.pth")
        
        

In [14]:
# Run training
train_model(model, train_loader, valid_loader, criterion, optimizer)

Epoch 1, Loss: 0.20940416968531078, Accuracy: 93.05555555555556%
Validation Loss: 0.10299264639616013, Validation Accuracy: 96.7032967032967%
Epoch 2, Loss: 0.37274056513849485, Accuracy: 86.48504273504274%
Validation Loss: 0.30067891099800664, Validation Accuracy: 86.81318681318682%
Epoch 3, Loss: 0.3205676314413038, Accuracy: 89.39636752136752%
Validation Loss: 0.28633904705444974, Validation Accuracy: 92.3076923076923%
Epoch 4, Loss: 0.17341858360311416, Accuracy: 94.04380341880342%
Validation Loss: 0.11076579242944717, Validation Accuracy: 97.8021978021978%
Epoch 5, Loss: 0.16086507593400967, Accuracy: 93.99038461538461%
Validation Loss: 0.14334202061096826, Validation Accuracy: 93.4065934065934%
Epoch 6, Loss: 0.15890930438191336, Accuracy: 94.33760683760684%
Validation Loss: 0.23347988951718435, Validation Accuracy: 94.50549450549451%
Epoch 7, Loss: 0.14267525693767855, Accuracy: 95.11217948717949%
Validation Loss: 0.09789642800266544, Validation Accuracy: 95.6043956043956%
Epoch