## Imports

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler
import numpy as np
  

In [2]:
# --- Run configuration ---
dataset_dir = 'fruits-360'  # root folder of the image dataset
batch_size = 256            # samples per mini-batch

# Per-channel mean / std used to normalize the inputs (ImageNet statistics)
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Preprocessing shared by all three splits: convert each image to a tensor,
# then normalize every channel with the ImageNet mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# One ImageFolder per split; each split lives in its own sub-folder of
# `dataset_dir` and the class labels come from the directory names.
train_dataset = torchvision.datasets.ImageFolder(f'{dataset_dir}/train', transform=transform)
val_dataset = torchvision.datasets.ImageFolder(f'{dataset_dir}/val', transform=transform)
test_dataset = torchvision.datasets.ImageFolder(f'{dataset_dir}/test', transform=transform)

# Batch iterators: train and val are reshuffled on every pass, test keeps a
# fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
# Build a ResNet-18 without pretrained weights and swap its final fully
# connected layer for one that has a single output per dataset class.
num_classes = len(train_dataset.classes)
model = torchvision.models.resnet18(pretrained=False)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
model = model.to(device)  # Module.to() returns the module itself

print(model)



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
# Define loss function and optimizer.
# CrossEntropyLoss takes the raw class scores (logits) produced by the model
# together with integer class labels.
criterion = nn.CrossEntropyLoss()
# Adam has no `momentum` keyword (passing one raises TypeError). Its
# momentum-like coefficient is the first entry of `betas`, so the intended
# 0.9 is expressed there; (0.9, 0.999) are also Adam's default values.
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999))

In [None]:
# Training loop.
# Each epoch runs one optimisation pass over `train_loader`, then measures the
# loss on `val_loader`. Whenever the validation loss improves on the best value
# seen so far, the model is checkpointed to 'model.pth'.
num_epochs = 50
min_loss = np.inf  # lowest summed validation loss seen so far
for epoch in range(num_epochs):
    print(f'training... epoch {epoch}')
    running_loss = 0.0  # sum of per-batch training losses for this epoch
    val_loss = 0.0      # sum of per-batch validation losses for this epoch

    # ---- training pass ----
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # ---- validation pass ----
    # Model selection must use the validation split; the test split is kept
    # untouched so the final test accuracy stays an unbiased estimate.
    model.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            vloss = criterion(outputs, labels)
            val_loss += vloss.item()

    # The number of validation batches is the same every epoch, so comparing
    # the summed losses ranks epochs exactly like comparing the means would.
    if val_loss < min_loss:
        min_loss = val_loss
        # NOTE(review): this pickles the whole module object; saving
        # model.state_dict() is the more portable format, but the file layout
        # is kept as-is for whatever already loads 'model.pth'.
        torch.save(model, 'model.pth')
        print(f'saving model at epoch {epoch}')

    # Report every epoch (not only when a checkpoint is written), averaging
    # over the real batch counts; max(..., 1) guards against an empty loader.
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)
    avg_train_loss = running_loss / max(num_train_batches, 1)
    avg_val_loss = val_loss / max(num_val_batches, 1)
    print(f'Epoch {epoch + 1}, Batch {num_train_batches}: loss {avg_train_loss:.3f} val_loss {avg_val_loss:.3f}')

print('Training finished!')



In [None]:
# Measure top-1 accuracy of the in-memory model on the test split.
model.eval()  # evaluation mode
correct, total = 0, 0
with torch.no_grad():  # no gradients are needed for scoring
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        # Predicted class = index of the largest score along the class axis
        predicted = outputs.argmax(dim=1)
        total += labels.shape[0]
        correct += int(predicted.eq(labels).sum())

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')
