In [1]:
import numpy as np
import PIL

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import os
import csv

In [2]:
torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [4]:
class DataProcess():
    def __init__(self, label_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(label_file, header=None)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = PIL.Image.open(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [5]:
def test(best, name, test_loader):
    test=models.resnet18()
    test.to(device)
    test.load_state_dict(torch.load("./models/resnet18/"+name+"_Adam_last.pt", weights_only=True))
    test.eval()
    total=0
    correct=0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            images_test, labels_test = data
            images_test, labels_test = images_test.to(device), labels_test.to(device)
            outputs=test(images_test)
            _, predicts = torch.max(outputs, 1)

            total += labels_test.size(0)

            correct += (predicts==labels_test).sum().item()
        
        accuracy = 100*correct/total
        
        if best < accuracy:
            print(f'{best:.4f} < {accuracy: .4f}')
            torch.save(test.state_dict(), "./models/resnet18/"+name+"_Adam_best.pt")
        
        print(f'Accuracy: {accuracy}%')

    return accuracy

In [7]:
def train( name, train_loader, resnet18, loss_function, optimizer):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        output = resnet18(inputs)
        # print(output[0].size())
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f'Loss: {running_loss / len(train_loader): .4f}; {running_loss:.4f};{len(train_loader):.4f}')

    torch.save(resnet18.state_dict(), "./models/resnet18/"+name+"_Adam_last.pt")
    return (running_loss / len(train_loader))

In [8]:
def write_csv(name, data):
    with open("./models/resnet18/"+name+"_dataset_Adam.csv",'w+',newline='') as file:
        writer=csv.writer(file)
        writer.writerows(data)

In [9]:
FOLDERS=["default","no_bg","random_bg"]

In [10]:
for name in FOLDERS:
    torch.cuda.empty_cache()
    if ('resnet18' in locals() or 'resnet18' in globals()) and ('loss_function' in locals() or 'loss_function' in globals()) and ('optimizer' in locals() or 'optimizer' in globals()):
        del(resnet18, loss_function, optimizer)
    resnet18 = models.resnet18()
    resnet18.to(device=device)
    resnet18.train()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(resnet18.parameters(), lr=0.003)
    # optimizer = optim.SGD(resnet18.parameters(), lr=0.003, momentum=0.9)
    
    csv_data=[['EPOCH','Loss','Accuracy']]
    
    train_data=DataProcess(label_file="./dataset/"+name+"/labels/TRAIN.csv",img_dir="./dataset/"+name+"/TRAIN/",transform=transform)
    test_data=DataProcess(label_file="./dataset/"+name+"/labels/TEST.csv",img_dir="./dataset/"+name+"/TEST/",transform=transform)
    train_loader= torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0)
    test_loader= torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=False, num_workers=0)
    
    best=0
    EPOCH=60
    
    for epoch in range(EPOCH):
        print(f'Training epoch {epoch}...')

        Loss=train( name, train_loader, resnet18, loss_function, optimizer)

        last=test(best, name, test_loader)
        if best < last:
            best=last
        csv_data.append([epoch,Loss,last])
    
    write_csv(name=name, data=csv_data)

Training epoch 0...
Loss:  1.6725; 88.6442;53.0000
0.0000 <  30.0300
Accuracy: 30.03003003003003%
Training epoch 1...
Loss:  1.0055; 53.2909;53.0000
30.0300 <  41.5616
Accuracy: 41.56156156156156%
Training epoch 2...
Loss:  0.8610; 45.6344;53.0000
41.5616 <  55.6156
Accuracy: 55.61561561561562%
Training epoch 3...
Loss:  0.7515; 39.8271;53.0000
Accuracy: 54.8948948948949%
Training epoch 4...
Loss:  0.6814; 36.1120;53.0000
55.6156 <  74.7748
Accuracy: 74.77477477477477%
Training epoch 5...
Loss:  0.4834; 25.6188;53.0000
Accuracy: 50.210210210210214%
Training epoch 6...
Loss:  0.4120; 21.8334;53.0000
74.7748 <  80.3604
Accuracy: 80.36036036036036%
Training epoch 7...
Loss:  0.2941; 15.5862;53.0000
Accuracy: 53.693693693693696%
Training epoch 8...
Loss:  0.2461; 13.0430;53.0000
80.3604 <  96.2763
Accuracy: 96.27627627627628%
Training epoch 9...
Loss:  0.2269; 12.0280;53.0000
Accuracy: 89.54954954954955%
Training epoch 10...
Loss:  0.1823; 9.6613;53.0000
Accuracy: 91.11111111111111%
Traini

In [13]:
validate_defult=DataProcess(label_file="./dataset/default/labels/VALIDATE.csv",img_dir="./dataset/default/VALIDATE/",transform=transform)
validate_defult_loader= torch.utils.data.DataLoader(validate_defult, batch_size=10, shuffle=False, num_workers=0)
validate_no_bg=DataProcess(label_file="./dataset/no_bg/labels/VALIDATE.csv",img_dir="./dataset/no_bg/VALIDATE/",transform=transform)
validate_no_bg_loader= torch.utils.data.DataLoader(validate_no_bg, batch_size=10, shuffle=False, num_workers=0)
validate_random_bg=DataProcess(label_file="./dataset/random_bg/labels/VALIDATE.csv",img_dir="./dataset/random_bg/VALIDATE/",transform=transform)
validate_random_bg_loader= torch.utils.data.DataLoader(validate_random_bg, batch_size=10, shuffle=False, num_workers=0)

validate_loaders=[validate_defult_loader,validate_no_bg_loader,validate_random_bg_loader]

In [17]:
def validate(loader: torch.utils.data.DataLoader):
    validate_model = models.resnet18()
    validate_model.to(device)
    validate_model.load_state_dict(torch.load("./models/resnet18/random_bg_Adam_best.pt", weights_only=True))
    validate_model.eval()
    total=0
    correct=0
    with torch.no_grad():
        for data in loader:
            image, label =data
            image, label = image.to(device), label.to(device)
            outputs=validate_model(image)
            _, predicts = torch.max(outputs, 1)

            total += label.size(0)

            correct += (predicts==label).sum().item()
        
        accuracy = 100*correct/total
        print(f'Accuracy: {accuracy: .4f}')

In [18]:
for loader in validate_loaders:
    validate(loader)

Accuracy:  66.4493
Accuracy:  66.3895
Accuracy:  94.9670
