In [1]:
import numpy as np
import PIL

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import os

In [2]:
torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [10]:
class DataProcess():
    def __init__(self, label_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(label_file, header=None)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = PIL.Image.open(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [11]:
train_data=DataProcess(label_file="./dataset/default/labels/TRAIN.csv",img_dir="./dataset/default/TRAIN/",transform=transform)
test_data=DataProcess(label_file="./dataset/default/labels/TEST.csv",img_dir="./dataset/default/TEST/",transform=transform)

In [12]:
train_loader= torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0)
test_loader= torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=0)

In [13]:
resnet18 = models.resnet18()
resnet18.to(device=device)
resnet18.train()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet18.parameters(), lr=0.003, momentum=0.9)

In [14]:
def test(best):
    test=models.resnet18()
    test.to(device)
    test.load_state_dict(torch.load("./models/resnet18/default_last.pt", weights_only=True))
    test.eval()
    total=0
    correct=0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            images_test, labels_test = data
            images_test, labels_test = images_test.to(device), labels_test.to(device)
            outputs=test(images_test)
            _, predicts = torch.max(outputs, 1)

            total += labels_test.size(0)

            correct += (predicts==labels_test).sum().item()
        
        accuracy = 100*correct/total
        
        if best < accuracy:
            torch.save(test.state_dict(), "./models/resnet18/default_best.pt")
        
        print(f'Accuracy: {accuracy}%')

    return accuracy

In [16]:
def train():
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        output = resnet18(inputs)
        # print(output[0].size())
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()

    running_loss += loss.item()

    print(f'Loss: {running_loss / len(train_loader): .4f}')

    torch.save(resnet18.state_dict(), "./models/resnet18/default_last.pt")

In [17]:
best=0
EPOCH=60
for epoch in range(EPOCH):
    print(f'Training epoch {epoch}...')

    train()

    last=test(best)
    if best < last:
        best=last

Training epoch 0...
Loss:  0.0214
Accuracy: 25.46082949308756%
Training epoch 1...
Loss:  0.0169
Accuracy: 44.75806451612903%
Training epoch 2...
Loss:  0.0447
Accuracy: 55.24193548387097%
Training epoch 3...
Loss:  0.0455
Accuracy: 57.085253456221196%
Training epoch 4...
Loss:  0.0123
Accuracy: 64.8041474654378%
Training epoch 5...
Loss:  0.0357
Accuracy: 65.09216589861751%
Training epoch 6...
Loss:  0.0595
Accuracy: 62.096774193548384%
Training epoch 7...
Loss:  0.0059
Accuracy: 67.10829493087557%
Training epoch 8...
Loss:  0.0630
Accuracy: 82.20046082949308%
Training epoch 9...
Loss:  0.0415
Accuracy: 89.34331797235023%
Training epoch 10...
Loss:  0.0310
Accuracy: 83.92857142857143%
Training epoch 11...
Loss:  0.0065
Accuracy: 83.81336405529954%
Training epoch 12...
Loss:  0.0029
Accuracy: 95.44930875576037%
Training epoch 13...
Loss:  0.0048
Accuracy: 93.26036866359448%
Training epoch 14...
Loss:  0.0027
Accuracy: 96.54377880184332%
Training epoch 15...
Loss:  0.0030
Accuracy: 98.3

In [18]:
validate_defult=DataProcess(label_file="./dataset/default/labels/VALIDATE.csv",img_dir="./dataset/default/VALIDATE/",transform=transform)
validate_defult_loader= torch.utils.data.DataLoader(validate_defult, batch_size=10, shuffle=False, num_workers=0)
validate_no_bg=DataProcess(label_file="./dataset/no_bg/labels/VALIDATE.csv",img_dir="./dataset/no_bg/VALIDATE/",transform=transform)
validate_no_bg_loader= torch.utils.data.DataLoader(validate_no_bg, batch_size=10, shuffle=False, num_workers=0)
validate_random_bg=DataProcess(label_file="./dataset/random_bg/labels/VALIDATE.csv",img_dir="./dataset/random_bg/VALIDATE/",transform=transform)
validate_random_bg_loader= torch.utils.data.DataLoader(validate_random_bg, batch_size=10, shuffle=False, num_workers=0)

validate_loaders=[validate_defult_loader,validate_no_bg_loader,validate_random_bg_loader]

In [19]:
def validate(loader: torch.utils.data.DataLoader):
    validate_model = models.resnet18()
    validate_model.to(device)
    validate_model.load_state_dict(torch.load("./models/resnet18/default_best.pt", weights_only=True))
    validate_model.eval()
    total=0
    correct=0
    with torch.no_grad():
        for data in loader:
            image, label =data
            image, label = image.to(device), label.to(device)
            outputs=validate_model(image)
            _, predicts = torch.max(outputs, 1)

            total += label.size(0)

            correct += (predicts==label).sum().item()
        
        accuracy = 100*correct/total
        print(f'Accuracy: {accuracy: .4f}')

In [20]:
for loader in validate_loaders:
    validate(loader)

Accuracy:  98.8145
Accuracy:  26.9596
Accuracy:  36.4889
