In [11]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

import numpy as np
from img_funcs import *
from utils import *

from tqdm import tqdm

In [2]:
imagesfiles_a = r"...\drumstick"
imagesfiles_b = r"...\duebel"
imagesfiles_a = append_slash_to_dirpath_if_not_present(imagesfiles_a)
imagesfiles_b = append_slash_to_dirpath_if_not_present(imagesfiles_b)

In [3]:
class BinaryClassificationDataset(Dataset):
    def __init__(self, imgs_a, imgs_b, size):
        self.size = size
        self.labels = torch.cat((torch.zeros(len(imgs_a)), torch.ones(len(imgs_b))))
        self.imgs = imgs_a + imgs_b

    def __len__(self):
        return len(self.imgs)

    def forward_transform(self, img):
        #img = np.transpose(img, (2, 0, 1))
        img = torch.Tensor(img)
        return img

    def __getitem__(self, idx):
        img, label = self.imgs[idx], self.labels[idx]
        img = augment(img)
        img = self.forward_transform(img)
        return img, label

In [4]:
imgs_a = import_images_from_dir(imagesfiles_a, max_num_images=184)
imgs_b = import_images_from_dir(imagesfiles_b, max_num_images=184)

def validation_split(imgs_a, imgs_b, split=0.2):
    split_a = round(len(imgs_a) * split)
    split_b = round(len(imgs_b) * split)
    return imgs_a[split_a:], imgs_b[split_b:], imgs_a[:split_a], imgs_b[:split_b]

Early stopping because max_num_images is reached.
184  images successfully loaded.
Early stopping because max_num_images is reached.
184  images successfully loaded.


In [5]:
train_a, train_b, test_a, test_b = validation_split(imgs_a, imgs_b, split=0.2)

SIZE = 128
BATCH_SIZE = 16
device = 'cpu'

In [6]:
train_dataset = BinaryClassificationDataset(train_a, train_b, SIZE)
test_dataset  = BinaryClassificationDataset(test_a, test_b, SIZE)
train_dataset = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
test_dataset  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

In [7]:
class Model(nn.Module):
    def __init__(self,
        n_input_neurons,
        n_hidden_neurons):
        super().__init__()

        self.layer1 = nn.Linear(n_input_neurons, n_hidden_neurons)
        self.layer2 = nn.Linear(n_hidden_neurons, 1)

    def forward(self, img):
        x = torch.flatten(img, start_dim=1)
        x = self.layer1(x)
        x = F.relu(x)
        x = self.layer2(x)
        x = F.sigmoid(x)
        return x

In [8]:
M = Model(3 * SIZE**2, 1000)
opt = Adam(M.parameters(), lr=1e-4)

In [17]:
def train_epoch():
    epoch_loss = 0
    pbar = tqdm(train_dataset, total=len(train_dataset))
    for img_batch, label_batch in pbar:
        opt.zero_grad()
        pred = M(img_batch)
        loss = torch.mean(torch.abs(pred - label_batch))
        loss.backward()
        opt.step()
        loss = loss.detach().cpu().numpy()
        epoch_loss += loss
        pbar.set_description("mae: "+str(np.round(loss, 4)))
    epoch_loss /= len(train_dataset)
    return epoch_loss

def evaluate():
    with torch.no_grad():
        test_loss = 0
        for img_batch, label_batch in test_dataset:
            pred = M(img_batch)
            loss = torch.mean(torch.abs(pred - label_batch))
            test_loss += loss.detach().cpu().numpy()
        test_loss /= len(test_dataset)
        return test_loss


In [18]:
for i in range(10):
    train_loss = train_epoch()
    eval_loss = evaluate()

    print('train: '+str(train_loss)+' eval: ' + str(eval_loss))

mae: 0.5: 100%|██████████| 19/19 [00:16<00:00,  1.13it/s]   


train: 0.5024671052631579 eval: 0.5130625009536743


mae: 1.0: 100%|██████████| 19/19 [00:15<00:00,  1.24it/s]   


train: 0.5209703947368421 eval: 0.515625


mae: 0.3889: 100%|██████████| 19/19 [00:15<00:00,  1.23it/s]


train: 0.49826388923745407 eval: 0.515625


mae: 0.3333: 100%|██████████| 19/19 [00:15<00:00,  1.25it/s]


train: 0.49081688648776006 eval: 0.4884999990463257


mae: 0.5: 100%|██████████| 19/19 [00:15<00:00,  1.24it/s]   


train: 0.501233552631579 eval: 0.496875


mae: 0.5: 100%|██████████| 19/19 [00:15<00:00,  1.22it/s]   


train: 0.5053453947368421 eval: 0.4845625042915344


mae: 0.5: 100%|██████████| 19/19 [00:16<00:00,  1.18it/s]   


train: 0.5020559210526315 eval: 0.503125


mae: 0.3333: 100%|██████████| 19/19 [00:16<00:00,  1.13it/s]


train: 0.49904057069828633 eval: 0.48125


mae: 0.625:  11%|█         | 2/19 [00:01<00:15,  1.09it/s] 


KeyboardInterrupt: 

In [28]:
torch.flatten(x, start_d)

torch.Size([16, 128, 128, 3])