In [None]:
import os
import PIL.Image
import numpy as np
import pandas as pd
import torch
import torchvision
import cv2
import numpy as np
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from model import AttentionModel

In [None]:
# define parameters
# Everything a reader is expected to tune lives in this cell; the cells below
# read these names directly.
mydatadir='dataset/'  # root folder holding the split list files and the images they reference
img_size = 224  # images are resized to img_size x img_size by both transform pipelines
batch_size = 64  # batch size for both the train and the test DataLoader
num_workers = 10  # worker processes per DataLoader
lr = 1e-3  # learning rate handed to the Adam optimizer
max_epochs = 100  # number of train + evaluate passes in the training loop
ckpt_dir = ''  # folder for 'best.pth' and the 'log' subfolder; '' makes the joined paths relative to the working directory
class_names = ['no-STAS', 'STAS']  # position in this list is the integer label used in the split files

# Split names: the dataset opens '<mydatadir>/<split>.txt'.
# NOTE(review): both are empty placeholders here and presumably must be set
# to real split names before running the data-loading cell.
split_train = ''
split_val = ''

# Attention model
# AttentionModel comes from the project-local `model` module; it is built with
# the number of classes as its only argument.
model = AttentionModel(len(class_names))
print(model)

# TensorBoard event files are written to '<ckpt_dir>/log'.
writer = SummaryWriter(os.path.join(ckpt_dir, 'log'))

In [None]:
# define dataset
class ClassificationDataset(torch.utils.data.Dataset):
    """Image-classification dataset driven by a tab-separated split file.

    The file ``<datadir>/<split>.txt`` lists one sample per line as
    ``<image path relative to datadir><TAB><integer label>``. Blank lines are
    ignored. The label is an index into the class-name list.

    Args:
        datadir: Root folder containing the split file and the images.
        split: Name of the split; ``split + '.txt'`` is the file that is read.
        transforms: Optional callable applied to the loaded RGB PIL image.
        classes: Optional list of class names used for label validation and
            the per-class summary. Defaults to the notebook-level
            ``class_names`` so existing calls keep working.

    Raises:
        ValueError: If a line is not ``name<TAB>label``, the label is not an
            integer, or the label is outside ``range(len(classes))``.
    """

    def __init__(self, datadir='', split='train', transforms=None, classes=None):
        if classes is None:
            # Fall back to the list defined in the configuration cell.
            classes = class_names

        self.split = split
        self.transforms = transforms

        self.imdb = []
        stats = {lbl: 0 for lbl in classes}

        listfile = os.path.join(datadir, split + '.txt')
        with open(listfile) as ff:
            for lineno, line in enumerate(ff, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue

                # Split on the LAST tab so the label is always the final field.
                fields = line.rsplit('\t', 1)
                if len(fields) != 2:
                    raise ValueError(
                        '%s:%d: expected "<image>\\t<label>", got %r'
                        % (listfile, lineno, line))
                imgname, label = fields
                label = int(label)
                if not 0 <= label < len(classes):
                    # A negative label would otherwise silently index from
                    # the end of the class list.
                    raise ValueError(
                        '%s:%d: label %d outside [0, %d)'
                        % (listfile, lineno, label, len(classes)))

                self.imdb.append({
                    'imgpath': os.path.join(datadir, imgname),
                    'label': label,
                })
                stats[classes[label]] += 1

        print('split: %s, total image num: %d' % (split, len(self.imdb)))
        for classname in stats:
            print('    %s: %d' % (classname, stats[classname]))

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is loaded as RGB and passed through ``self.transforms``
        when one was given; the label is the integer read from the split file.
        """
        record = self.imdb[index]

        # The context manager closes the underlying file as soon as the
        # pixels have been decoded by convert().
        with PIL.Image.open(record['imgpath']) as raw:
            img = raw.convert('RGB')

        if self.transforms is not None:
            img = self.transforms(img)

        return img, record['label']

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.imdb)

In [None]:
# Per-channel normalisation applied after ToTensor in both pipelines
# (these values match the widely used ImageNet mean/std).
normalize = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: resize, random photometric/geometric augmentation,
# then tensor conversion and normalisation.
transforms_train = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=[img_size, img_size]),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(degrees=15),
    torchvision.transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic, no augmentation.
transforms_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=[img_size, img_size]),
    torchvision.transforms.ToTensor(),
    normalize,
])

# Settings shared by the train and test loaders.
common_loader_args = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

# Training data: reshuffled every epoch, incomplete final batch dropped.
train_dataset = ClassificationDataset(
    datadir=mydatadir, split=split_train, transforms=transforms_train)
train_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, drop_last=True, **common_loader_args)

# Evaluation data: fixed order, every sample kept.
test_dataset = ClassificationDataset(
    datadir=mydatadir, split=split_val, transforms=transforms_test)
test_loader = torch.utils.data.DataLoader(
    test_dataset, shuffle=False, drop_last=False, **common_loader_args)

In [None]:
# Move the model to the GPU and wrap it for multi-GPU data parallelism.
# The isinstance guard keeps this cell idempotent: re-running it must not
# nest one DataParallel inside another, because the training cell saves
# `model.module` and expects that to be the bare network.
model = model.cuda()
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model)

# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Cross-entropy on raw logits with integer class labels (mean over the batch).
criterian = torch.nn.CrossEntropyLoss()

In [None]:
# Train for `max_epochs` epochs. After every epoch the model is evaluated on
# the test loader, both metrics are logged to TensorBoard, and the model with
# the best test accuracy so far is saved to '<ckpt_dir>/best.pth'.

# os.makedirs('') raises FileNotFoundError, so an empty ckpt_dir is mapped to
# the current directory (which is where the joined paths end up anyway).
os.makedirs(ckpt_dir or '.', exist_ok=True)
best_score = 0
for epoch in range(max_epochs):

    # ---- train for one epoch ----
    model.train()
    running_loss = 0.
    running_acc = 0
    train_seen = 0  # samples actually processed this epoch
    for idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        output = model(inputs)
        loss = criterian(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_batch = labels.size(0)
        # The criterion returns the batch mean; scale it back to a sum so the
        # epoch figure is a true per-sample average.
        running_loss += loss.item() * n_batch
        _, predict = torch.max(output, 1)
        running_acc += (predict == labels).sum().item()
        train_seen += n_batch

    # Divide by the samples actually seen, not by len(train_dataset): the
    # train loader drops the incomplete last batch, so the dataset length
    # would under-report accuracy. max(..., 1) avoids dividing by zero when
    # the loader yields nothing.
    running_loss /= max(train_seen, 1)
    running_acc /= max(train_seen, 1)

    # ---- evaluate after each epoch ----
    model.eval()
    testloss = 0.
    testacc = 0.
    test_seen = 0
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            output = model(inputs)
            loss = criterian(output, labels)

            n_batch = labels.size(0)
            # Weight by batch size so a short final batch is not over-counted.
            testloss += loss.item() * n_batch
            _, predict = torch.max(output, 1)
            testacc += (predict == labels).sum().item()
            test_seen += n_batch

    testloss /= max(test_seen, 1)
    testacc /= max(test_seen, 1)

    writer.add_scalars("loss",{"train":running_loss,"test":testloss},epoch+1)
    writer.add_scalars("acc",{"train":running_acc,"test":testacc},epoch+1)

    print("[%d/%d] Train loss: %.4f, Train acc: %.3f; Test loss: %.4f, Test acc: %.3f" %(
        epoch + 1, max_epochs,
        running_loss, running_acc,
        testloss, testacc,
    ))

    # Keep only the best model by test accuracy. `model.module` is the network
    # inside the DataParallel wrapper; the whole module object is pickled,
    # so loading it later needs the same model class to be importable.
    if testacc > best_score:
        best_score = testacc
        torch.save(model.module, os.path.join(ckpt_dir, 'best.pth'))
writer.close()