In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
from efficientnet_pytorch import EfficientNet
from pathlib import Path
import numpy as np
import multiprocessing
import math

In [22]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


## Model

In [23]:
# EfficientNet configuration: variant b{v}, 2-channel input, a single predicted output.
v, in_c, num_c = 0, 2, 1  # model version, input channels, number of outputs

In [24]:
# Example of the network input: a batch of one optical-flow map with in_c == 2 channels.
# PyTorch conv layers take (batch, channels, height, width), so this dummy is 640 high
# and 480 wide. NOTE(review): if the flow frames are 640x480 (W x H), a realistic dummy
# would be torch.randn(1, 2, 480, 640) - confirm against the RAFT output.
# of = torch.randn(1,2,640,480)

In [25]:
# Pretrained EfficientNet-b{v} built with `in_c` input channels and `num_c` outputs,
# then moved to the selected device (nn.Module.to returns the module itself).
model_name = f'efficientnet-b{v}'
model = EfficientNet.from_pretrained(model_name, in_channels=in_c, num_classes=num_c)
model = model.to(device)

Loaded pretrained weights for efficientnet-b0


#### Example forward pass: the model returns one scalar prediction per input

In [26]:
# Example forward pass; uncomment together with the dummy `of` defined (commented) above.
# With num_c == 1 and a batch of one, the output has a single element, so .item()
# returns it as a Python float.
# of = of.to(device)
# model(of).item()

## Data

In [27]:
# Label file: one line per sample; the first whitespace-separated token is the target.
labels_f = 'train.txt'
# Directory holding the optical-flow arrays, one `<index>.npy` file per sample.
# NOTE(review): "opical" looks like a typo of "optical", but the string must match the
# directory that actually exists on disk - confirm before renaming either one.
of_dir = '../opical-flow-estimation-with-RAFT/output'

In [28]:
class OFDataset(Dataset):
    """Map-style dataset pairing pre-computed optical-flow arrays with scalar labels.

    Sample ``idx`` is loaded from ``<of_dir>/<idx>.npy``. Its label is the first
    whitespace-separated token on line ``idx`` of ``label_f``. Files are therefore
    assumed to be named ``0.npy``, ``1.npy``, ... without gaps - TODO confirm
    against the script that exports the flow.

    Args:
        of_dir: directory containing the ``.npy`` flow arrays.
        label_f: path to the text label file, one sample per line.
    """

    def __init__(self, of_dir, label_f):
        self.of_dir = of_dir
        # The context manager closes the label file as soon as it has been read.
        with open(label_f) as fh:
            self.label_file = fh.readlines()
        n_flows = len(list(Path(of_dir).glob('*.npy')))
        # Only expose indices that have both a flow file and a label line, so
        # __getitem__ cannot run past the end of the label list.
        self.len = min(n_flows, len(self.label_file))

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``[flow, label]``.

        ``flow`` is a float32 tensor with all size-1 dimensions squeezed out.
        ``label`` is a Python float.
        """
        of_array = np.load(Path(self.of_dir) / f'{idx}.npy')
        # float32, same as the legacy torch.Tensor(ndarray) constructor produced.
        # squeeze drops size-1 dims (presumably a stored leading batch axis - confirm).
        of_tensor = torch.squeeze(torch.as_tensor(of_array, dtype=torch.float32))
        label = float(self.label_file[idx].split()[0])
        return [of_tensor, label]

In [29]:
# Full dataset; the train/validation split is applied later through samplers.
ds = OFDataset(of_dir=of_dir, label_f=labels_f)

In [30]:
# Fraction of samples used for training; the remaining 20% are held out for validation.
train_split = 0.8

In [31]:
# Contiguous (unshuffled) split: the first `train_split` fraction of indices is used
# for training and the remainder for validation.
ds_size = len(ds)
indices = [*range(ds_size)]
split = int(np.floor(ds_size * train_split))
train_idx = indices[:split]
val_idx = indices[split:]

In [32]:
# Sanity-check one sample before training. The flow must be a 3-D tensor whose first
# dimension equals the model's `in_c` input channels (the DataLoader adds the batch
# dimension). The label must be a Python float.
assert len(ds) > 0, 'dataset is empty - check of_dir and labels_f'
sample = ds[0]
assert isinstance(sample[0], torch.Tensor)
assert isinstance(sample[1], float)
assert sample[0].dim() == 3 and sample[0].shape[0] == in_c, (
    f'expected a (in_c={in_c}, H, W) flow tensor, got {tuple(sample[0].shape)}')

In [33]:
# One sampler per split; each draws only from its own index list, in random order.
train_sampler, val_sampler = [SubsetRandomSampler(idx) for idx in (train_idx, val_idx)]

In [34]:
cpu_cores = multiprocessing.cpu_count()  # logical CPUs on this machine, for reference
cpu_cores  # last expression: displayed as the cell's output

8

In [35]:
# Both loaders read the same Dataset; the sampler decides which indices each one sees.
# num_workers=0 loads batches in the main process.
loader_kwargs = dict(batch_size=8, num_workers=0)
train_dl = DataLoader(ds, sampler=train_sampler, **loader_kwargs)
val_dl = DataLoader(ds, sampler=val_sampler, **loader_kwargs)

In [None]:
def plot(train_loss, val_loss, title, save_path="./loss/result.png"):
    """Plot per-epoch training and validation MSE curves, save the figure, then show it.

    Args:
        train_loss: sequence of per-epoch training losses (numbers or 0-d tensors).
        val_loss: sequence of per-epoch validation losses; may be shorter than
            ``train_loss`` (e.g. a run interrupted between the two halves of an epoch).
        title: figure title.
        save_path: destination image path; missing parent directories are created.
    """
    # NOTE(review): `plt` is never imported in this notebook. Add
    # `import matplotlib.pyplot as plt` to the imports cell, otherwise this raises NameError.
    # float() accepts plain numbers and 0-d tensors (including CUDA tensors that carry
    # autograd history), which matplotlib cannot convert on its own.
    train_vals = [float(x) for x in train_loss]
    val_vals = [float(x) for x in val_loss]
    # Each curve gets its own x-range, so curves of unequal length do not raise.
    plt.plot(range(len(train_vals)), train_vals, label='train_loss')
    plt.plot(range(len(val_vals)), val_vals, label='val_loss')
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel("MSE")
    plt.grid(True)
    plt.legend()
    # savefig does not create missing directories.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path)
    plt.show()

## Train

In [36]:
log_train_steps = 100  # NOTE(review): currently not used anywhere in this notebook
epochs = 25            # number of full passes over the training split

In [37]:
# Adam over every model parameter; lr=1e-3 is spelled out but equals PyTorch's default.
opt = optim.Adam(model.parameters(), lr=1e-3)
# Regression objective: squared error averaged over all elements (reduction='mean').
criterion = nn.MSELoss()

In [None]:
# Train for `epochs` epochs, recording the per-sample MSE of both splits and saving the
# weights whenever the validation loss reaches a new best.


def _batch_loss(batch):
    """Return ``(mse, batch_size)`` for one ``[flow, label]`` batch from a DataLoader.

    Tensors are moved to ``device`` instead of hard-coding ``.cuda()``, so the cell also
    runs on CPU-only machines.
    """
    of_tensor, label = batch
    of_tensor = of_tensor.to(device)
    # Collated Python-float labels arrive as float64; cast to float32 like the original.
    label = label.float().to(device)
    # Assumes the model returns (batch, num_c) with num_c == 1. view(-1) flattens that
    # to (batch,); unlike torch.squeeze it keeps a size-1 last batch as shape (1,) so
    # it still matches `label`.
    pred = model(of_tensor).view(-1)
    return criterion(pred, label), label.size(0)


history_train_loss = []
history_val_loss = []
best_loss = math.inf
ckpt_path = Path('model/b0.pth')
# torch.save does not create missing directories.
ckpt_path.parent.mkdir(parents=True, exist_ok=True)

for epoch in range(epochs):
    model.train()
    train_sum, train_count = 0.0, 0
    for batch in train_dl:
        opt.zero_grad()
        loss, n = _batch_loss(batch)
        loss.backward()
        opt.step()
        # Keep a detached Python float: holding the `loss` tensor would keep this
        # batch's autograd graph (and its GPU memory) alive until the epoch ends.
        train_sum += loss.item() * n
        train_count += n
    # Per-sample mean: MSELoss averages within a batch, so weight each batch by its size.
    # An empty split yields NaN instead of raising ZeroDivisionError.
    mean_train_loss = train_sum / train_count if train_count else math.nan
    history_train_loss.append(mean_train_loss)
    print(f'{epoch}t: {mean_train_loss}')

    # validation
    model.eval()
    val_sum, val_count = 0.0, 0
    with torch.no_grad():
        for batch in val_dl:
            loss, n = _batch_loss(batch)
            val_sum += loss.item() * n
            val_count += n
    mean_val_loss = val_sum / val_count if val_count else math.nan
    # NaN never compares smaller, so an empty validation split never overwrites the checkpoint.
    if mean_val_loss < best_loss:
        torch.save(model.state_dict(), ckpt_path)
        best_loss = mean_val_loss
    history_val_loss.append(mean_val_loss)
    print(f'{epoch}: {mean_val_loss}')

0t: 2.590954065322876
0: 2.68108868598938
1t: 1.8530006408691406
1: 4.91492223739624
2t: 2.106901168823242
2: 2.2981743812561035
3t: 1.0425323247909546
3: 11.310953140258789
4t: 1.259174108505249
4: 17.798213958740234
5t: 1.0620850324630737
5: 2.86309814453125
6t: 1.119244933128357
6: 7.951549530029297
7t: 0.6982173323631287
7: 1.9866344928741455
8t: 0.8450481295585632
8: 2.269864797592163
9t: 0.9126380085945129
9: 1.6741011142730713
10t: 0.6199302673339844
10: 2.419679880142212
11t: 0.5272167325019836
11: 4.0807719230651855
12t: 0.808750569820404
12: 2.4169180393218994
13t: 0.40372994542121887
13: 2.5497241020202637
14t: 0.6809329986572266
14: 2.3896114826202393


In [None]:
# Loss curves for the training run above; plot() also writes them to ./loss/result.png.
plot(train_loss=history_train_loss, val_loss=history_val_loss, title="efficientnetb0")