In [None]:
import os
import torch
from torch import nn, save, load
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import Adam
from model import ImageClassifier

In [None]:
# Number of full passes over the training set performed by the training cell.
EPOCHS: int = 10

In [None]:
class FlagDataset(Dataset):
  """Image dataset laid out as ``root_dir/<class_name>/<image_file>``.

  Each immediate sub-directory of ``root_dir`` is one class, named after the
  folder. Labels are returned as integer indices into ``self.classes`` so the
  default DataLoader collate produces a LongTensor that ``nn.CrossEntropyLoss``
  accepts (raw string labels cannot be moved to a device or used as targets).

  NOTE(review): ``len(self.classes)`` must equal the number of outputs of the
  classifier being trained — confirm against model.py.

  Attributes:
    classes: sorted list of class (folder) names; ``classes[i]`` names index i.
    class_to_idx: mapping from class name to its integer index.
    image_paths: path of every sample.
    labels: integer class index of every sample, parallel to ``image_paths``.
  """

  def __init__(self, root_dir, transform=None):
    """Scan ``root_dir`` and index every image file it contains.

    Args:
      root_dir: directory containing one sub-directory per class.
      transform: optional callable applied to each loaded PIL image.
    """
    self.root_dir = root_dir
    self.transform = transform

    # Sort so the name -> index mapping is identical across runs and machines
    # (os.listdir order is arbitrary), and skip stray top-level files such as
    # .DS_Store, which would otherwise make os.listdir below raise.
    self.classes = sorted(
        entry for entry in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, entry))
    )
    self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

    self.image_paths = []
    self.labels = []
    for class_name in self.classes:
      class_path = os.path.join(root_dir, class_name)
      for image_file in sorted(os.listdir(class_path)):
        image_path = os.path.join(class_path, image_file)
        if not os.path.isfile(image_path):
          continue  # ignore nested directories
        self.image_paths.append(image_path)
        self.labels.append(self.class_to_idx[class_name])

  def __len__(self):
    """Return the number of image files found under ``root_dir``."""
    return len(self.image_paths)

  def __getitem__(self, idx):
    """Load sample ``idx`` and return ``(image, label_index)``.

    The image is converted to 3-channel RGB, then passed through
    ``self.transform`` when one was given.
    """
    # Context manager closes the underlying file handle once converted.
    with Image.open(self.image_paths[idx]) as raw:
      image = raw.convert('RGB')
    if self.transform:
      image = self.transform(image)
    return image, self.labels[idx]

In [None]:
data_dir = 'data'

# Preprocessing shared by training and inference. Resize takes (height, width);
# presumably 20x30 is the input size ImageClassifier expects — TODO confirm
# against model.py.
transform = transforms.Compose([
    transforms.Resize((20, 30)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = FlagDataset(data_dir, transform=transform)
train_data = DataLoader(dataset, batch_size=32, shuffle=True)

# Fall back to CPU when no GPU is available instead of crashing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
clf = ImageClassifier().to(device)
opt = Adam(clf.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [None]:

# Training flow: fit the classifier, persist its weights, then sanity-check the
# reloaded checkpoint on a single image.
MODEL_PATH = 'model_state.pt'
SAMPLE_IMAGE = 'img_3.jpg'

# Follow whatever device the setup cell placed the model on instead of
# hard-coding 'cuda', so this cell also runs on CPU-only machines.
device = next(clf.parameters()).device
total_batches = len(train_data)

for epoch in range(EPOCHS):
    clf.train()  # training-mode behaviour for layers such as dropout/batch-norm
    running_loss = 0.0

    for batch_idx, (X, y) in enumerate(train_data, start=1):
        X, y = X.to(device), y.to(device)
        yhat = clf(X)
        loss = loss_fn(yhat, y)

        # Apply backprop
        opt.zero_grad()
        loss.backward()
        opt.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        print(f"Epoch: {epoch} | Batch: {batch_idx}/{total_batches} | Loss: {batch_loss:.4f}")

    # Report the mean of the per-batch losses rather than only the last batch;
    # max(..., 1) avoids dividing by zero (and a NameError) on an empty loader.
    epoch_loss = running_loss / max(total_batches, 1)
    print(f"Epoch:{epoch} loss is {epoch_loss:.4f}")

with open(MODEL_PATH, 'wb') as f:
    save(clf.state_dict(), f)

# Round-trip the checkpoint to confirm it loads; map_location keeps tensors on
# the model's device even if the file was written on a different one.
with open(MODEL_PATH, 'rb') as f:
    clf.load_state_dict(load(f, map_location=device))

# Inference must see the same preprocessing as training: RGB conversion (as in
# FlagDataset) plus the shared resize/normalize `transform`. The former bare
# `ToTensor` was never imported and skipped both steps.
with Image.open(SAMPLE_IMAGE) as raw:
    img = raw.convert('RGB')
img_tensor = transform(img).unsqueeze(0).to(device)

clf.eval()
with torch.no_grad():
    pred = torch.argmax(clf(img_tensor), dim=1)
print(pred.item())