In [1]:
import numpy as np
import PIL

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import os
import csv

In [2]:
torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
FOLDERS=["default","no_bg","random_bg"]

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [5]:
class DataProcess():
    def __init__(self, label_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(label_file, header=None)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = PIL.Image.open(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [6]:
def test(best, name, test_loader):
    test=models.alexnet()
    test.to(device)
    test.load_state_dict(torch.load("./models/alexnet/"+name+"_Adam_last.pt", weights_only=True))
    test.eval()
    total=0
    correct=0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            images_test, labels_test = data
            images_test, labels_test = images_test.to(device), labels_test.to(device)
            outputs=test(images_test)
            _, predicts = torch.max(outputs, 1)

            total += labels_test.size(0)

            correct += (predicts==labels_test).sum().item()
        
        accuracy = 100*correct/total
        
        if best < accuracy:
            print(f"{best: .4f} < {accuracy: .4f}")
            torch.save(test.state_dict(), "./models/alexnet/"+name+"_Adam_best.pt")
        
        print(f'Accuracy: {accuracy: .4f}%')

    return accuracy

In [7]:
def train(name, train_loader, alexnet, loss_function, optimizer):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        output = alexnet(inputs)
        # print(output[0].size())
        loss = loss_function(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f'Loss: {running_loss / len(train_loader): .4f}')

    torch.save(alexnet.state_dict(), "./models/alexnet/"+name+"_Adam_last.pt")
    return (running_loss / len(train_loader))

In [8]:
def write_csv(name, data):
    with open("./models/alexnet/"+name+"_dataset_Adam.csv",'w+',newline='') as file:
        writer=csv.writer(file)
        writer.writerows(data)

In [9]:
for name in FOLDERS:
    torch.cuda.empty_cache()
    if ('alexnet' in locals() or 'alexnet' in globals()) and ('loss_function' in locals() or 'loss_function' in globals()) and ('optimizer' in locals() or 'optimizer' in globals()):
        del(alexnet, loss_function, optimizer)
    alexnet = models.alexnet()
    alexnet.to(device=device)
    alexnet.train()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(alexnet.parameters(), lr=0.003)
    # optimizer = optim.SGD(alexnet.parameters(), lr=0.003, momentum=0.9)
    
    csv_data=[['EPOCH','Loss','Accuracy']]
    train_data=DataProcess(label_file="./dataset/"+name+"/labels/TRAIN.csv",img_dir="./dataset/"+name+"/TRAIN/",transform=transform)
    test_data=DataProcess(label_file="./dataset/"+name+"/labels/TEST.csv",img_dir="./dataset/"+name+"/TEST/",transform=transform)
    train_loader= torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0)
    test_loader= torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=False, num_workers=0)
    best=0
    EPOCH=60
    
    for epoch in range(EPOCH):
        print(f'Training epoch {epoch}...')

        Loss=train( name, train_loader, alexnet, loss_function, optimizer)

        last=test(best, name, test_loader)
        if best < last:
            best=last
        csv_data.append([epoch,Loss,last])
    
    write_csv(name=name, data=csv_data)

Training epoch 0...
Loss:  46.8438
 0.0000 <  23.8438
Accuracy:  23.8438%
Training epoch 1...
Loss:  1.4089
 23.8438 <  24.5045
Accuracy:  24.5045%
Training epoch 2...
Loss:  1.4091
 24.5045 <  24.9249
Accuracy:  24.9249%
Training epoch 3...
Loss:  1.4176
Accuracy:  24.5045%
Training epoch 4...
Loss:  1.4065
Accuracy:  23.8438%
Training epoch 5...
Loss:  1.4034
Accuracy:  24.5045%
Training epoch 6...
Loss:  1.3988
Accuracy:  24.9249%
Training epoch 7...
Loss:  1.4043
Accuracy:  23.8438%
Training epoch 8...
Loss:  1.3931
Accuracy:  24.9249%
Training epoch 9...
Loss:  1.3946
Accuracy:  23.8438%
Training epoch 10...
Loss:  1.3919
 24.9249 <  26.7267
Accuracy:  26.7267%
Training epoch 11...
Loss:  1.3892
Accuracy:  26.7267%
Training epoch 12...
Loss:  1.3909
Accuracy:  24.9249%
Training epoch 13...
Loss:  1.3900
Accuracy:  26.7267%
Training epoch 14...
Loss:  1.3861
Accuracy:  26.7267%
Training epoch 15...
Loss:  1.3883
Accuracy:  24.5045%
Training epoch 16...
Loss:  1.3900
Accuracy:  26.7

In [10]:
validate_defult=DataProcess(label_file="./dataset/default/labels/VALIDATE.csv",img_dir="./dataset/default/VALIDATE/",transform=transform)
validate_defult_loader= torch.utils.data.DataLoader(validate_defult, batch_size=10, shuffle=False, num_workers=0)
validate_no_bg=DataProcess(label_file="./dataset/no_bg/labels/VALIDATE.csv",img_dir="./dataset/no_bg/VALIDATE/",transform=transform)
validate_no_bg_loader= torch.utils.data.DataLoader(validate_no_bg, batch_size=10, shuffle=False, num_workers=0)
validate_random_bg=DataProcess(label_file="./dataset/random_bg/labels/VALIDATE.csv",img_dir="./dataset/random_bg/VALIDATE/",transform=transform)
validate_random_bg_loader= torch.utils.data.DataLoader(validate_random_bg, batch_size=10, shuffle=False, num_workers=0)

validate_loaders=[validate_defult_loader,validate_no_bg_loader,validate_random_bg_loader]

In [18]:
def validate(loader: torch.utils.data.DataLoader):
    validate_model = models.alexnet()
    validate_model.to(device)
    validate_model.load_state_dict(torch.load("./models/alexnet/random_bg_SGD_best.pt", weights_only=True))
    validate_model.eval()
    total=0
    correct=0
    with torch.no_grad():
        for data in loader:
            image, label =data
            image, label = image.to(device), label.to(device)
            outputs=validate_model(image)
            _, predicts = torch.max(outputs, 1)

            total += label.size(0)

            correct += (predicts==label).sum().item()
        
        accuracy = 100*correct/total
        print(f'Accuracy: {accuracy: .4f}')

In [19]:
for loader in validate_loaders:
    validate(loader)

Accuracy:  67.1014
Accuracy:  72.3278
Accuracy:  94.4278
