In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
import torchvision
from torchvision import transforms, utils
import time
from sklearn.metrics import precision_recall_curve,precision_score,recall_score,accuracy_score

In [8]:
class Cfg:
    """Hyper-parameters shared by the data, model and training cells."""
    batch_size = 16        # training mini-batch size
    learning_rate = 0.001  # Adam learning rate
    weight_decay = 0.001   # Adam weight decay
    max_epoch = 50         # training loop iterates epochs 0..max_epoch inclusive
    log_interval = 50      # print training progress every N batches
    val_interval = 200     # NOTE(review): not read by the visible cells; validation cadence is set in the training cell
    seed = 42              # seed for torch's global RNG (weight init, DataLoader shuffling)


cfg = Cfg()
# Seed torch's global RNG so model initialisation and training-set shuffling are reproducible
# under Restart & Run All (random_split below uses its own explicitly seeded generator).
torch.manual_seed(cfg.seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
DATA_PATH = "ds/"

# One sub-directory per class under DATA_PATH; every image is resized to 64x64
# and converted to a float tensor.
data_whole = torchvision.datasets.ImageFolder(
    root=DATA_PATH,
    transform=transforms.Compose(transforms=[
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
    ]),
)

In [4]:
# Half of the images (rounded down) go to training; the rest are held out.
train_set_size = len(data_whole) // 2
utility_set_size = len(data_whole) - train_set_size

print(train_set_size, utility_set_size)

1913 1913


In [5]:
# Split with the sizes computed in the previous cell; the dedicated, fixed-seed
# generator keeps the train/val partition identical across runs.
data_train, data_val = torch.utils.data.random_split(
    data_whole,
    [train_set_size, utility_set_size],
    generator=torch.Generator().manual_seed(42),
)
data_loader_tr = torch.utils.data.DataLoader(data_train, batch_size=cfg.batch_size, shuffle=True)
# Validation metrics do not depend on sample order, so the held-out loader is not
# shuffled (avoids needlessly consuming the global RNG); the last partial batch is kept.
data_loader_va = torch.utils.data.DataLoader(data_val, batch_size=32, shuffle=False, drop_last=False)

In [6]:
# Show both split sizes; a bare non-final expression would be silently discarded.
len(data_train), len(data_val)

1913

In [12]:
# VGG-13 with batch norm and a 2-way output head (no pretrained weights requested).
model = torchvision.models.vgg13_bn(num_classes=2)
model = model.to(device)

In [13]:
# Adam with weight decay; both hyper-parameters come from cfg.
optimizer = optim.Adam(params=model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
# Sigmoid + binary cross-entropy per logit; the training loop supplies one-hot float targets.
criterion = nn.BCEWithLogitsLoss()

# Metric histories appended to by the training loop.
train_loss, train_acc, val_acc = [], [], []

In [None]:
# Stop training as soon as validation accuracy exceeds this value.
TARGET_VAL_ACC = 0.98
# Run validation every this many epochs (epoch 0 included).
VAL_EVERY_EPOCHS = 5

for epoch in range(cfg.max_epoch + 1):
    # Ensure BatchNorm/Dropout are in training mode, even if a previous run of this
    # cell stopped while the model was in eval mode.
    model.train()
    acc_num = 0   # correct training predictions so far this epoch
    seen_num = 0  # training samples seen so far this epoch (last batch may be partial)
    for batch_idx, (data, target) in enumerate(data_loader_tr):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # BCEWithLogitsLoss needs float targets shaped like the logits, hence one-hot.
        loss = criterion(output, F.one_hot(target, num_classes=2).float())
        loss.backward()
        optimizer.step()

        acc_num += (torch.argmax(output, dim=1) == target).sum().item()
        seen_num += target.size(0)
        running_acc = acc_num / seen_num

        if batch_idx % cfg.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Acc = {}'.format(
                epoch, batch_idx * len(data), len(data_loader_tr.dataset),
                100. * batch_idx / len(data_loader_tr), loss.item(), running_acc))
        train_acc.append(running_acc)
        train_loss.append(loss.item())

    if epoch % VAL_EVERY_EPOCHS == 0:
        print("Validating...")
        model.eval()
        # Collect per-batch arrays and concatenate once, instead of re-allocating every batch.
        batch_preds, batch_labels = [], []
        with torch.no_grad():
            for data, target in data_loader_va:
                output = model(data.to(device))
                batch_preds.append(torch.argmax(output, dim=1).cpu().numpy())
                batch_labels.append(target.cpu().numpy())
        all_out = np.concatenate(batch_preds)
        all_label = np.concatenate(batch_labels)

        precision = precision_score(all_label, all_out, average='macro')
        recall = recall_score(all_label, all_out, average='macro')
        acc = accuracy_score(all_label, all_out)
        print("P:", precision, "R:", recall)
        print(acc)
        # Record the result and restore training mode *before* the early-stop check,
        # so the final accuracy is kept and the model is not left in eval mode.
        val_acc.append(acc)
        model.train()
        if acc > TARGET_VAL_ACC:
            break

Validating...
