In [1]:
!pip install timm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm

# Seed PyTorch's random number generators so that data shuffling and any
# randomly initialised layers are repeatable from one run to the next.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tunable settings for the data pipeline and the classifier head.
IMAGE_SIZE = 224   # side length every image is resized to
BATCH_SIZE = 32
NUM_CLASSES = 10   # CIFAR-10 has ten target classes
DATA_ROOT = './data'

# Prepare the dataset
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # Resize images to match the input size expected by EfficientNet-B0
    transforms.ToTensor(),
    # Shift and scale each channel with mean 0.5 and std 0.5.
    # NOTE(review): the pretrained weights may have been trained with different
    # normalisation statistics -- confirm against the model's pretrained config.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Training split: shuffled every epoch.
trainset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

# Test split: same preprocessing, fixed order.
testset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

# Create the model
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=NUM_CLASSES)  # Output will be predictions for 10 classes
model = model.to(device)

# Set up the optimizer and the loss function (PyTorch has no separate
# "compile" step); Adam is used with its default hyper-parameters.
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Train the model for 9 epochs
NUM_EPOCHS = 9
for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0.0    # running sum of per-sample losses for this epoch
    train_correct = 0   # number of correctly classified training samples
    total = 0           # number of training samples seen this epoch
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Standard optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_batch = labels.size(0)
        # The criterion yields the mean loss over the batch; weight it by the
        # batch size so a smaller final batch does not skew the epoch average.
        train_loss += loss.item() * n_batch
        # Predicted class = index of the largest logit (no need for `.data`).
        predicted = outputs.argmax(dim=1)
        total += n_batch
        train_correct += (predicted == labels).sum().item()

    # Per-sample averages over the whole epoch.
    train_loss = train_loss / total
    train_accuracy = train_correct / total

    print(f'Epoch {epoch+1}:')
    print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}\n')

# Saving the model (weights only, as a state dict)
CHECKPOINT_PATH = 'model.pt'
torch.save(model.state_dict(), CHECKPOINT_PATH)

# Reload the weights that were just saved. This is the state after the final
# epoch -- no best-epoch selection is performed in this notebook.
# `map_location` places the tensors on the active device, and
# `weights_only=True` restricts unpickling to tensors and plain containers,
# avoiding arbitrary code execution from the checkpoint file.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))

Collecting timm
  Downloading timm-0.9.12-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.12
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29394554.33it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Epoch 1:
Train Loss: 0.4369, Train Accuracy: 0.8590

Epoch 2:
Train Loss: 0.2448, Train Accuracy: 0.9169

Epoch 3:
Train Loss: 0.2023, Train Accuracy: 0.9308

Epoch 4:
Train Loss: 0.1691, Train Accuracy: 0.9429

Epoch 5:
Train Loss: 0.1382, Train Accuracy: 0.9532

Epoch 6:
Train Loss: 0.1234, Train Accuracy: 0.9576

Epoch 7:
Train Loss: 0.1047, Train Accuracy: 0.9641

Epoch 8:
Train Loss: 0.0892, Train Accuracy: 0.9700

Epoch 9:
Train Loss: 0.0840, Train Accuracy: 0.9712



<All keys matched successfully>

In [2]:
# Evaluate the model on the test set
model.eval()        # switch layers to inference behaviour
test_loss = 0.0     # running sum of per-sample losses
test_correct = 0    # number of correctly classified test samples
total = 0           # number of test samples seen
with torch.no_grad():  # no gradients are needed for evaluation
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        n_batch = labels.size(0)
        # The criterion yields the mean loss over the batch; weight it by the
        # batch size so a smaller final batch does not skew the overall average.
        test_loss += loss.item() * n_batch
        # Predicted class = index of the largest logit.
        predicted = outputs.argmax(dim=1)
        total += n_batch
        test_correct += (predicted == labels).sum().item()

# Per-sample averages over the whole test set.
test_loss = test_loss / total
test_accuracy = test_correct / total

print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}\n')

Test Loss: 0.2402, Test Accuracy: 0.9294

