In [12]:
import numpy as np
import PIL

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import os

In [13]:
torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [14]:
FOLDERS=["default","no_bg","random_bg"]

In [15]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [16]:
class DataProcess():
    def __init__(self, label_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(label_file, header=None)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = PIL.Image.open(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [17]:
alexnet = models.alexnet()
alexnet.to(device=device)
alexnet.train()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(alexnet.parameters(), lr=0.003, momentum=0.9)

In [18]:
def reset():
    global alexnet, loss_function, optimizer
    alexnet = models.alexnet()
    alexnet.to(device=device)
    alexnet.train()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(alexnet.parameters(), lr=0.003, momentum=0.9)

In [19]:
def test(best, name, test_loader):
    test=models.alexnet()
    test.to(device)
    test.load_state_dict(torch.load("./models/alexnet/"+name+"_last.pt", weights_only=True))
    test.eval()
    total=0
    correct=0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            images_test, labels_test = data
            images_test, labels_test = images_test.to(device), labels_test.to(device)
            outputs=test(images_test)
            _, predicts = torch.max(outputs, 1)

            total += labels_test.size(0)

            correct += (predicts==labels_test).sum().item()
        
        accuracy = 100*correct/total
        
        if best < accuracy:
            torch.save(test.state_dict(), "./models/alexnet/"+name+"_best.pt")
        
        print(f'Accuracy: {accuracy}%')

    return accuracy

In [20]:
def train(name, train_loader):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        output = alexnet(inputs)
        # print(output[0].size())
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()

    running_loss += loss.item()

    print(f'Loss: {running_loss / len(train_loader): .4f}')

    torch.save(alexnet.state_dict(), "./models/alexnet/"+name+"_last.pt")

In [21]:
for name in FOLDERS:
    reset()
    train_data=DataProcess(label_file="./dataset/"+name+"/labels/TRAIN.csv",img_dir="./dataset/"+name+"/TRAIN/",transform=transform)
    test_data=DataProcess(label_file="./dataset/"+name+"/labels/TEST.csv",img_dir="./dataset/"+name+"/TEST/",transform=transform)
    train_loader= torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0)
    test_loader= torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=False, num_workers=0)
    best=0
    EPOCH=60
    for epoch in range(EPOCH):
        print(f'Training epoch {epoch}...')

        train( name, train_loader)

        last=test(best, name, test_loader)
        if best < last:
            best=last

Training epoch 0...
Loss:  0.1043
Accuracy: 26.726726726726728%
Training epoch 1...
Loss:  0.0131
Accuracy: 32.972972972972975%
Training epoch 2...
Loss:  0.0348
Accuracy: 58.97897897897898%
Training epoch 3...
Loss:  0.0011
Accuracy: 72.61261261261261%
Training epoch 4...
Loss:  0.0027
Accuracy: 81.32132132132132%
Training epoch 5...
Loss:  0.0025
Accuracy: 82.58258258258259%
Training epoch 6...
Loss:  0.0001
Accuracy: 81.56156156156156%
Training epoch 7...
Loss:  0.0000
Accuracy: 84.98498498498499%
Training epoch 8...
Loss:  0.0000
Accuracy: 84.74474474474475%
Training epoch 9...
Loss:  0.0017
Accuracy: 83.18318318318319%
Training epoch 10...
Loss:  0.0000
Accuracy: 88.16816816816817%
Training epoch 11...
Loss:  0.0000
Accuracy: 90.2102102102102%
Training epoch 12...
Loss:  0.0000
Accuracy: 87.62762762762763%
Training epoch 13...
Loss:  0.0315
Accuracy: 76.45645645645645%
Training epoch 14...
Loss:  0.0005
Accuracy: 78.37837837837837%
Training epoch 15...
Loss:  0.0009
Accuracy: 75.7

In [24]:
validate_defult=DataProcess(label_file="./dataset/default/labels/VALIDATE.csv",img_dir="./dataset/default/VALIDATE/",transform=transform)
validate_defult_loader= torch.utils.data.DataLoader(validate_defult, batch_size=10, shuffle=False, num_workers=0)
validate_no_bg=DataProcess(label_file="./dataset/no_bg/labels/VALIDATE.csv",img_dir="./dataset/no_bg/VALIDATE/",transform=transform)
validate_no_bg_loader= torch.utils.data.DataLoader(validate_no_bg, batch_size=10, shuffle=False, num_workers=0)
validate_random_bg=DataProcess(label_file="./dataset/random_bg/labels/VALIDATE.csv",img_dir="./dataset/random_bg/VALIDATE/",transform=transform)
validate_random_bg_loader= torch.utils.data.DataLoader(validate_random_bg, batch_size=10, shuffle=False, num_workers=0)

validate_loaders=[validate_defult_loader,validate_no_bg_loader,validate_random_bg_loader]

In [29]:
def validate(loader: torch.utils.data.DataLoader):
    validate_model = models.alexnet()
    validate_model.to(device)
    validate_model.load_state_dict(torch.load("./models/alexnet/random_bg_best.pt", weights_only=True))
    validate_model.eval()
    total=0
    correct=0
    with torch.no_grad():
        for data in loader:
            image, label =data
            image, label = image.to(device), label.to(device)
            outputs=validate_model(image)
            _, predicts = torch.max(outputs, 1)

            total += label.size(0)

            correct += (predicts==label).sum().item()
        
        accuracy = 100*correct/total
        print(f'Accuracy: {accuracy: .4f}')

In [30]:
for loader in validate_loaders:
    validate(loader)

Accuracy:  64.8488
Accuracy:  74.6437
Accuracy:  94.4877
