In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image


In [4]:
class load_dataset(Dataset):
    """Folder-per-class image dataset, loaded eagerly into memory.

    ``root_dir`` is expected to contain one sub-directory per class; every file
    inside a class directory is read as an image. Non-directory entries directly
    under ``root_dir`` (e.g. ``.DS_Store``) are ignored.

    Class indices follow the *sorted* sub-directory names, so splits built from
    different roots (train / test) agree on which integer means which class.

    Each image is converted to single-channel ("L") mode, scaled to [0, 1] and
    stored as a float tensor of shape (1, H, W).

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each image tensor in
            ``__getitem__`` (e.g. a ``torchvision.transforms`` pipeline).
            ``None`` (the default) returns the stored tensor unchanged.

    Attributes:
        classes: Sorted list of class-directory names; index == label.
        images: List of (1, H, W) float tensors.
        labels: List of int labels aligned with ``images``.
    """

    def __init__(self, root_dir, transform=None):
        # os.listdir order is arbitrary and may differ between directories, so
        # sort to keep the label mapping identical across splits.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.transform = transform
        self.labels = []
        self.images = []
        for i, em in enumerate(self.classes):
            em_dir = os.path.join(root_dir, em)
            for img in sorted(os.listdir(em_dir)):
                img_path = os.path.join(em_dir, img)
                # PIL opens files lazily and keeps the handle open; the context
                # manager closes it once the pixels have been copied out.
                # convert("L") guarantees a single channel for the 1-channel CNN.
                with Image.open(img_path) as image:
                    img_array = np.array(image.convert("L"))
                img_tensor = torch.from_numpy(img_array).float() / 255.0
                # Add the channel dimension: (H, W) -> (1, H, W).
                self.images.append(img_tensor.unsqueeze(0))
                self.labels.append(i)

    def __len__(self):
        """Number of loaded images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for index ``idx``, applying ``transform`` if set."""
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]


In [9]:
class CNN(nn.Module):
    """Small VGG-style convolutional classifier.

    Three convolutional stages (16 -> 32 -> 64 channels), each ending in a 2x2
    max-pool, followed by a two-layer MLP head with dropout and batch norm.
    The head's first Linear layer takes ``64 * 6 * 6`` features, so the input's
    spatial size must pool down to 6x6 after three halvings (e.g. 48x48).

    Args:
        in_channels: Number of input image channels. Defaults to 1.
        num_classes: Number of output logits. Defaults to 7.
    """

    def __init__(self, in_channels=1, num_classes=7):
        super().__init__()
        # Module order is significant: state_dict keys are positional
        # ("net.<index>.weight"), so reordering would break saved checkpoints.
        self.net = nn.Sequential(
            # Stage 1: two 3x3 convs at 16 channels, then halve H and W.
            nn.Conv2d(in_channels, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
            nn.MaxPool2d(2),

            # Stage 2: two 3x3 convs at 32 channels, then halve H and W.
            nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.MaxPool2d(2),

            # Stage 3: one 3x3 conv at 64 channels, then halve H and W.
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),

            # Classifier head: expects a 64 x 6 x 6 feature map at this point.
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(64 * 6 * 6, 256), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to raw logits of shape (N, num_classes)."""
        return self.net(x)

In [6]:
# Build both splits from their class-per-folder directories (paths are relative
# to the notebook's working directory), then wrap each in a DataLoader.
# Only the training split is shuffled; evaluation order is irrelevant.
train_dataset, test_dataset = (
    load_dataset("../data/train"),
    load_dataset("../data/test"),
)

train_loader, test_loader = (
    DataLoader(train_dataset, shuffle=True, batch_size=64),
    DataLoader(test_dataset, shuffle=False, batch_size=64),
)

In [11]:
# Seed torch's global RNG so weight initialisation, dropout masks and the
# shuffle order of train_loader (shuffle=True with no explicit generator) are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = CNN()

batch_size = 64  # not used below; the DataLoaders were built with batch_size=64
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.1 every 5 scheduler steps (one step per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    # One pass over the training split.
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

    # Top-1 accuracy on the test split. eval() makes BatchNorm use its running
    # statistics and disables Dropout.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            output = model(images)
            _, pred = torch.max(output, 1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

    scheduler.step()
    res = 100 * correct / total
    print(f"epoch {epoch + 1}, acc: {res:.2f}%")

epoch 1, acc: 48.97%
epoch 2, acc: 53.79%
epoch 3, acc: 50.29%
epoch 4, acc: 55.98%
epoch 5, acc: 52.70%
epoch 6, acc: 59.38%
epoch 7, acc: 59.14%
epoch 8, acc: 59.61%
epoch 9, acc: 59.61%
epoch 10, acc: 60.23%


In [12]:
# Baseline for checkpointing in the continued-training cell: the test accuracy
# of the model as it currently stands (the last epoch of the first training
# loop). Taken from `res` instead of a number copied from printed output, so it
# stays correct when the notebook is re-run.
best_acc = res

In [15]:
# NOTE(review): creating a fresh Adam optimizer discards the moment estimates
# accumulated in the first run and restarts the learning rate at 1.5e-3, even
# though the first run's StepLR schedule had already decayed its rate. Confirm
# this warm restart is intended rather than reusing the existing optimizer.
optimizer = optim.Adam(model.parameters(), lr=0.0015)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# torch.save does not create missing directories; make sure the target exists
# before the first checkpoint is written.
CHECKPOINT_DIR = "../best_models"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Continue epoch numbering after the first 10 epochs.
for epoch in range(10, 15):
    model.train()

    for images, labels in train_loader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

    # Top-1 accuracy on the test split.
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            output = model(images)
            _, pred = torch.max(output, 1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

    scheduler.step()
    res = 100 * correct / total
    print(f"epoch {epoch+1}, acc: {res:.2f}%")

    # Save a checkpoint only when this epoch beats the best accuracy so far.
    if res > best_acc:
        best_acc = res
        torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/model_checkpoint_{res:.2f}.pth")

epoch 11, acc: 58.01%
epoch 12, acc: 59.17%
epoch 13, acc: 58.87%
epoch 14, acc: 60.62%
epoch 15, acc: 60.21%
