In [None]:
import timm
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import os

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# --- Hyper-parameters -------------------------------------------------------
num_countries = 6   # number of output classes of the classification head
batch_size = 32     # samples per mini-batch
lr = 0.001          # learning rate
num_epochs = 10     # full passes over the training data
seed = 42           # RNG seed for repeatable runs

# Seed PyTorch's global RNG before any weights are created so the freshly
# initialised head (and anything else drawing from the global generator)
# is repeatable across kernel restarts. GPU kernels may still introduce
# small non-deterministic differences.
torch.manual_seed(seed)

# --- Model ------------------------------------------------------------------
# Load EfficientNet-B2 with pretrained weights and replace its classifier
# with dropout followed by a linear layer producing `num_countries` logits.
model = timm.create_model('efficientnet_b2', pretrained=True)
model.classifier = nn.Sequential(
    nn.Dropout(p=0.3),
    # in_features is read from the original classifier so the new head
    # matches the backbone's feature width.
    nn.Linear(model.classifier.in_features, num_countries)
)
model = model.to(device)

In [None]:
transforms = transforms.Compose([
    transforms.Resize((260, 260)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
train_dataset = ImageFolder(root='data/train', transform=transforms) # change 'data/train' to your actual training data path
val_dataset = ImageFolder(root='data/val', transform=transforms) # change 'data/val' to your actual validation data path

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Adam over all model parameters, using the configured learning rate `lr`
# rather than a duplicated literal.
optimizer = Adam(model.parameters(), lr=lr)
# Cross-entropy on the raw logits; the default reduction returns the batch mean.
criterion = nn.CrossEntropyLoss()

# training
for epoch in range(num_epochs):
    train_loss = 0.0   # sum of per-sample losses over the epoch
    train_correct = 0  # correctly classified training samples
    total = 0          # training samples seen

    model.train()  # training mode: the Dropout layer in the head is active

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is a batch mean; weight it by the batch size so a smaller
        # final batch does not skew the epoch average.
        batch_count = labels.size(0)
        train_loss += loss.item() * batch_count
        predicted = outputs.argmax(dim=1)
        total += batch_count
        train_correct += (predicted == labels).sum().item()

    train_acc = 100. * train_correct / total
    print(f"Train Loss: {train_loss/total:.4f}, Train Accuracy: {train_acc:.2f}%")

    # validation
    val_loss = 0.0   # sum of per-sample losses over the validation set
    val_correct = 0  # correctly classified validation samples
    val_total = 0    # validation samples seen
    model.eval()  # evaluation mode: Dropout is disabled
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            batch_count = labels.size(0)
            val_loss += loss.item() * batch_count
            predicted = outputs.argmax(dim=1)
            val_total += batch_count
            val_correct += (predicted == labels).sum().item()

    val_acc = 100. * val_correct / val_total
    print(f"Validation Loss: {val_loss/val_total:.4f}, Validation Accuracy: {val_acc:.2f}%")