In [None]:
# Mount Google Drive under /content/drive so its files are reachable from this runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Download ImageWoof dataset
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2.tgz
!tar -xvzf imagewoof2.tgz

In [None]:
# Imports
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [None]:
# Data transforms
# ImageNet channel statistics: the IMAGENET1K_V1 ResNet-18 weights loaded in
# the model cell expect inputs normalized with these values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: random scale/aspect crop + horizontal flip, so each epoch sees
# slightly different views and the fine-tuned network overfits less.
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation: a deterministic pipeline kept separate from train_tf, so random
# augmentation never reaches validation. Resize(256) + CenterCrop(224)
# preserves aspect ratio (instead of squashing to 224x224) and matches the
# preprocessing published with the ImageNet weights.
test_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [None]:
# Load datasets
# ImageFolder derives label indices from the sorted sub-directory names, so the
# train and val splits must contain the same class folders for labels to agree.
DATA_DIR = 'imagewoof2'
BATCH_SIZE = 32
NUM_WORKERS = 2  # background processes for JPEG decoding/augmentation; tune to the runtime's CPU count

train_ds = datasets.ImageFolder(f'{DATA_DIR}/train', transform=train_tf)
val_ds = datasets.ImageFolder(f'{DATA_DIR}/val', transform=test_tf)
assert train_ds.class_to_idx == val_ds.class_to_idx, 'train/val class folders differ'

# pin_memory speeds up host-to-GPU copies; it only has an effect when CUDA is used.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, **loader_kwargs)

In [None]:
# Model: ImageNet-pretrained ResNet-18 with a new classification head.
SEED = 42
# Seed the global torch RNG before the new head is created. nn.Linear's weight
# init draws from it, and so does the shuffle order of a DataLoader built without
# an explicit generator, so training runs become repeatable (up to GPU kernel
# non-determinism).
torch.manual_seed(SEED)

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# NOTE(review): torchvision's ResNet already uses AdaptiveAvgPool2d((1, 1)) here,
# so this assignment does NOT modify the architecture. It is kept only to make
# the global-pooling step explicit; replace it if a different pooling was intended.
model.avgpool = nn.AdaptiveAvgPool2d((1,1))
# Swap the 1000-way ImageNet classifier for one output per dataset class.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(train_ds.classes))

In [None]:
# Training setup: choose the accelerator, move the model onto it, and build the
# loss function and optimizer over all model parameters.
LEARNING_RATE = 1e-3

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)  # nn.Module.to moves parameters in place and returns the same module
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Training
NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    model.train()  # re-enter training mode at the start of every epoch
    loss_sum, n_seen = 0.0, 0
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        # criterion (default reduction) returns the batch *mean*; weight it by the
        # batch size so a short final batch doesn't count as much as a full one.
        batch_n = y.size(0)
        loss_sum += loss.item() * batch_n
        n_seen += batch_n
    print(f'Epoch {epoch+1}, Loss: {loss_sum/n_seen:.4f}')

In [None]:
# Validation accuracy: top-1 accuracy of the model over every batch of val_dl,
# computed in eval mode with gradient tracking disabled.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in val_dl:
        images = images.to(device)
        labels = labels.to(device)
        predicted = model(images).argmax(dim=1)  # index of the highest logit per sample
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
print(f'Accuracy: {100*correct/total:.2f}%')

In [None]:
# Save model
# Only the state_dict (parameter tensors) is saved. To reload, build the model
# the same way, then call model.load_state_dict(torch.load(path)).
# Write into the Google Drive mounted at /content/drive when it is available,
# because files on the Colab VM's local disk are lost when the runtime is recycled.
# Otherwise fall back to the current working directory.
drive_dir = Path('/content/drive/MyDrive')
save_dir = drive_dir if drive_dir.is_dir() else Path('.')
save_path = save_dir / 'imagewoof_model.pth'
torch.save(model.state_dict(), save_path)
print(f'Saved weights to {save_path}')