In [1]:
import os
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torchvision.transforms as TF
import torch
import torch.nn as nn
import torchvision.models
from tqdm.auto import tqdm

In [2]:
# Dataset root directory (relative to the working directory); the dataset
# class in this notebook reads images from its `def_front` and `ok_front`
# subfolders.
data_path = Path('casting_512x512')

In [3]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs end to end on machines without a usable CUDA device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
def create_model(num_classes, model_name):
    """Build an ImageNet-pretrained torchvision ResNet with a new classifier head.

    Args:
        num_classes: Number of outputs of the replacement `fc` layer.
        model_name: Name of a builder in `torchvision.models` that starts
            with `resnet` (for example `'resnet34'`).

    Returns:
        The pretrained model whose `fc` layer has been replaced by a freshly
        created `nn.Linear(model.fc.in_features, num_classes)`.

    Raises:
        ValueError: If `model_name` does not start with `resnet` or does not
            name a callable builder in `torchvision.models`.
    """
    if not model_name.startswith('resnet'):
        raise ValueError(f'Unsupported model name: `{model_name}`')

    # `torchvision.models.resnet` is a submodule, not a builder, and unknown
    # names do not exist at all; report both as an unsupported name.
    builder = getattr(torchvision.models, model_name, None)
    if not callable(builder):
        raise ValueError(f'Unsupported model name: `{model_name}`')

    try:
        # Current torchvision API; `pretrained=True` is deprecated there.
        model = builder(weights='IMAGENET1K_V1')
    except TypeError:
        # Older torchvision releases do not accept the `weights` keyword.
        model = builder(pretrained=True)

    model.fc = nn.Linear(model.fc.in_features, num_classes, bias=True)
    return model

In [5]:
class MyDs(Dataset):
    """Binary image dataset: defective castings (label 1), then OK castings (label 0).

    Indices `0 .. len(pos_files) - 1` address `pos_files`, read from the
    `def_front` subfolder; the remaining indices address `neg_files`, read
    from the `ok_front` subfolder.

    Args:
        pos_files: File names inside `<root>/def_front`.
        neg_files: File names inside `<root>/ok_front`.
        tfm: Optional callable applied to the loaded PIL image.
        root: Optional dataset root. When omitted, the module-level
            `data_path` is looked up at item-access time.
    """

    def __init__(self, pos_files, neg_files, tfm=None, root=None):
        self.pos_files = pos_files
        self.neg_files = neg_files
        self.tfm = tfm
        self.root = root

    def __len__(self):
        return len(self.pos_files) + len(self.neg_files)

    def __getitem__(self, i):
        """Return `(image, label)` for index `i`."""
        root = data_path if self.root is None else Path(self.root)
        n_pos = len(self.pos_files)
        if i < n_pos:
            pf = root / 'def_front' / self.pos_files[i]
            lbl = 1
        else:
            pf = root / 'ok_front' / self.neg_files[i - n_pos]
            lbl = 0
        # Close the file handle deterministically and always hand a
        # 3-channel image to the transform, whatever mode is stored on disk.
        with Image.open(pf) as im:
            image = im.convert('RGB')
        if self.tfm is not None:
            image = self.tfm(image)
        return image, lbl

In [6]:
class ValueCollector:
    """Accumulates chunks of values and merges them on request."""

    def __init__(self):
        self.values = []

    def put(self, vs):
        """Append one chunk (array, tensor, sequence or scalar)."""
        self.values.append(vs)

    def get(self):
        """Merge the stored chunks; the first chunk's type picks the strategy.

        NumPy arrays are concatenated along axis 0, torch tensors along
        dim 0, lists/tuples are flattened one level into a new list, and
        anything else is returned as a shallow copy of the chunk list.
        An empty collector yields an empty list.
        """
        if not self.values:
            return []
        first = self.values[0]
        if isinstance(first, np.ndarray):
            return np.concatenate(self.values, axis=0)
        if isinstance(first, torch.Tensor):
            return torch.cat(self.values, dim=0)
        if isinstance(first, (list, tuple)):
            return [item for chunk in self.values for item in chunk]
        return list(self.values)

In [7]:
def get_accuracy(lbls, preds):
    """Return the fraction of entries where `preds` equals `lbls`, as a Python float."""
    hits = preds == lbls
    return hits.float().mean().item()

def get_accuracies(lbls, preds, num_classes):
    """Return one accuracy per class index in `range(num_classes)`.

    For class `c`, only the positions where `lbls == c` are considered and
    the value is the fraction of those whose prediction is also `c`.
    """
    per_class = []
    for cls_idx in range(num_classes):
        in_class = lbls == cls_idx
        hits = preds[in_class] == cls_idx
        per_class.append(hits.float().mean().item())
    return per_class

def get_balanced_accuracy(lbls, preds, num_classes):
    """Return the unweighted mean of the per-class accuracies as a Python float."""
    class_accs = torch.tensor(get_accuracies(lbls, preds, num_classes))
    return class_accs.mean().item()

def mean(L:list):
    """Return the arithmetic mean of `L`, or NaN when `L` is empty."""
    count = len(L)
    return sum(L) / count if count else np.nan

In [8]:
def do_one_epoch(model, optimizer, loss_fn, dl, device):
    """Run one optimisation pass over every batch yielded by `dl`.

    The model is put into train mode for the pass and switched back to
    eval mode before returning. Nothing is returned.
    """
    model = model.train()
    batches = tqdm(dl, desc='Batch', leave=False)
    for inputs, targets in batches:
        optimizer.zero_grad()
        logits = model(inputs.to(device))
        batch_loss = loss_fn(logits, targets.to(device))
        batch_loss.backward()
        optimizer.step()
    model = model.eval()
    
def evaluate(model, loss_fn, dl, device, num_classes=2):
    """Evaluate `model` on every batch of `dl` without tracking gradients.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        loss_fn: Callable taking `(outputs, labels)` and returning a scalar
            tensor.
        dl: Iterable yielding `(inputs, labels)` batches.
        device: Device the inputs and labels are moved to.
        num_classes: Number of classes averaged for the balanced accuracy.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        A dict with keys `loss` (unweighted mean of the per-batch losses, so
        a smaller final batch counts as much as a full one), `acc_i`
        (overall accuracy) and `acc_c` (mean of the per-class accuracies).
    """
    model = model.eval()
    lbls_vc, preds_vc = ValueCollector(), ValueCollector()
    loss_vc = ValueCollector()
    for inp, lbl in tqdm(dl, desc='Batch', leave=False):
        with torch.no_grad():
            out = model(inp.to(device))
            loss = loss_fn(out, lbl.to(device))
        # Predicted class = index of the largest output along dim 1.
        preds_vc.put(out.argmax(dim=1).cpu())
        lbls_vc.put(lbl)
        loss_vc.put(loss.item())
    preds, lbls = preds_vc.get(), lbls_vc.get()
    return {'loss': mean(loss_vc.get()),
            'acc_i': get_accuracy(lbls, preds),
            'acc_c': get_balanced_accuracy(lbls, preds, num_classes)}

In [9]:
# File names are sorted first so the seeded shuffle below starts from the
# same order on every run.
pos_files = sorted(os.listdir(data_path / 'def_front'))
neg_files = sorted(os.listdir(data_path / 'ok_front'))

# NumPy's RNG drives the train/validation split; torch's global RNG is
# seeded as well so the shuffled order of the training DataLoader is
# repeatable across "Restart & Run All" (GPU kernels may still introduce
# small run-to-run differences).
np.random.seed(0)
torch.manual_seed(0)
np.random.shuffle(pos_files)
np.random.shuffle(neg_files)

# 80/20 split, done per class so both splits keep the class ratio.
_N = int(len(pos_files) * 0.8)
trn_pos_files, val_pos_files = pos_files[:_N], pos_files[_N:]
_N = int(len(neg_files) * 0.8)
trn_neg_files, val_neg_files = neg_files[:_N], neg_files[_N:]

# Same preprocessing for both splits: resize, convert to tensor, then
# normalise with the per-channel mean/std below.
_normalize = TF.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
tfm = TF.Compose([TF.Resize((256,256)), TF.ToTensor(), _normalize])
trn_ds = MyDs(trn_pos_files, trn_neg_files, tfm=tfm)
val_ds = MyDs(val_pos_files, val_neg_files, tfm=tfm)
trn_dl = DataLoader(trn_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=32, shuffle=False)

In [10]:
# Two-output ResNet-34 classifier on the selected device, optimised with
# Adam (lr 1e-4) against a cross-entropy loss.
model = create_model(num_classes=2, model_name='resnet34')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

In [11]:
def to_str(res):
    """Format a metrics dict (`loss`, `acc_i`, `acc_c`) as one log line.

    The loss is shown with four decimals; both accuracies are scaled to
    percent and shown with two decimals.
    """
    loss = res["loss"]
    acc_i_pct = res["acc_i"] * 100
    acc_c_pct = res["acc_c"] * 100
    return f'loss={loss:.4f} acc_i={acc_i_pct:.2f} acc_c={acc_c_pct:.2f}'

In [12]:
# Train for four epochs; after each one, score the model on both the
# training and the validation loader and print one summary line.
epoch_bar = tqdm(range(4), desc='Epoch')
for epoch in epoch_bar:
    do_one_epoch(model, optimizer, loss_fn, trn_dl, device)
    trn = evaluate(model, loss_fn, trn_dl, device)
    val = evaluate(model, loss_fn, val_dl, device)
    trn_txt, val_txt = to_str(trn), to_str(val)
    print(f'Epoch={epoch} TRN {trn_txt} VAL {val_txt}')

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

Batch:   0%|          | 0/33 [00:00<?, ?it/s]

Batch:   0%|          | 0/33 [00:00<?, ?it/s]

Batch:   0%|          | 0/9 [00:00<?, ?it/s]

Epoch=0 TRN loss=0.0345 acc_i=99.23 acc_c=99.36 VAL loss=0.0378 acc_i=99.23 acc_c=99.36


Batch:   0%|          | 0/33 [00:00<?, ?it/s]

Batch:   0%|          | 0/33 [00:00<?, ?it/s]

Batch:   0%|          | 0/9 [00:00<?, ?it/s]

Epoch=1 TRN loss=0.0043 acc_i=100.00 acc_c=100.00 VAL loss=0.0183 acc_i=99.23 acc_c=99.36


Batch:   0%|          | 0/33 [00:00<?, ?it/s]

Batch:   0%|          | 0/33 [00:00<?, ?it/s]

Batch:   0%|          | 0/9 [00:00<?, ?it/s]

Epoch=2 TRN loss=0.0009 acc_i=100.00 acc_c=100.00 VAL loss=0.0077 acc_i=99.62 acc_c=99.68


Batch:   0%|          | 0/33 [00:00<?, ?it/s]

Batch:   0%|          | 0/33 [00:00<?, ?it/s]

Batch:   0%|          | 0/9 [00:00<?, ?it/s]

Epoch=3 TRN loss=0.0029 acc_i=100.00 acc_c=100.00 VAL loss=0.0040 acc_i=100.00 acc_c=100.00


In [13]:
# Persist the whole trained model object (not just a state dict) to the
# working directory.
torch.save(obj=model, f='trained_model.pt')