In [71]:
import random
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

QUANTITY_OF_TEST_CASES = 20

def params_to_string(params: dict) -> str:
    name = ""
    for key, value in params.items():
        name += f"{key}_{value}-"
    return name[:-1]

class RotationAndFlipLayer(torch.nn.Module):
    """
    A layer that converts a (B, L, C, W, H) into a (B * 8 * L, C, W, H)
    """
    def __init__(self):
        super().__init__()
        self.metadata = None
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, levels, channels, width, height = x.shape

        self.metadata = { "batch": batch, "levels": levels, "transforms": 8 }
        
        x = x.reshape(batch * levels, channels, width, height)
        
        width_dim = len(x.shape) - 2
        height_dim = len(x.shape) - 1
        
        flipped: torch.Tensor = x.flip(dims=(width_dim,))
        flipped_rot90: torch.Tensor = flipped.rot90(k=1, dims=(width_dim, height_dim))
        flipped_rot180: torch.Tensor = flipped.rot90(k=2, dims=(width_dim, height_dim))
        flipped_rot270: torch.Tensor = flipped.rot90(k=3, dims=(width_dim, height_dim))
        rot90: torch.Tensor = x.rot90(k=1, dims=(width_dim, height_dim))
        rot180: torch.Tensor = x.rot90(k=2, dims=(width_dim, height_dim))
        rot270: torch.tensor = x.rot90(k=3, dims=(width_dim, height_dim))

        # return torch.cat((
        #     x,
        #     rot90,
        #     rot180,
        #     rot270,
        #     flipped, 
        #     flipped_rot90,
        #     flipped_rot180,
        #     flipped_rot270,
        # ), dim=0)
        
        return torch.cat((
            x,
            x,
            x,
            x,
            x, 
            x,
            x,
            x,
        ), dim=0)

def test_rotation_and_flip_layer():
    layer = RotationAndFlipLayer()
    
    for _ in range(QUANTITY_OF_TEST_CASES):
        batch = random.randint(1, 100)
        levels = random.randint(1, 100)

        x = torch.zeros((batch, levels, 1, 30, 30))
        out = layer.forward(x)

        transforms = layer.metadata["transforms"]

        expected = batch * levels * 8
        assert out.shape == torch.Size([expected, 1, 30, 30]), f"Test failed with batch {batch} and level {levels}"

test_rotation_and_flip_layer()

In [72]:
class DELIGHTModel(torch.nn.Module):
    """
    DELIGHT implementation written in torch.

    Allows inputs of ([B]atch, [L]evel, [C]hannels, [W]idth, [H]eight).

    """
    def __init__(self, debug: bool = False, nconv1: int = 52, nconv2: int = 57, nconv3: int = 41, nlinear1: int = 3280, nlinear2: int = 685):
        super().__init__()
        self.rot_and_flip = RotationAndFlipLayer()
        self.debug = debug
        self.CONV2D_ONE_OUT_CHANNELS = nconv1
        self.CONV2D_TWO_OUT_CHANNELS = nconv2
        self.CONV2D_THREE_OUT_CHANNELS = nconv3
        self.LINEAR_ONE_IN = nlinear1
        self.LINEAR_TWO_IN = nlinear2

        self.conv1 = torch.nn.Conv2d(
            in_channels=1,
            out_channels=self.CONV2D_ONE_OUT_CHANNELS, 
            kernel_size=3
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=self.conv1.out_channels, 
            out_channels=self.CONV2D_TWO_OUT_CHANNELS,
            kernel_size=3
        )
        self.conv3 = torch.nn.Conv2d(
            in_channels=self.conv2.out_channels, 
            out_channels=self.CONV2D_THREE_OUT_CHANNELS, 
            kernel_size=3
        )
        self.relu = torch.nn.ReLU()
        self.max_pool = torch.nn.MaxPool2d(kernel_size=2)
        self.flatten = torch.nn.Flatten()
        
        self.bottleneck = torch.nn.Sequential(
            self.conv1,
            self.relu,
            self.max_pool, 
            self.conv2,
            self.relu,
            self.max_pool,
            self.conv3,
            self.relu,
            self.flatten,
        )

        self.fc1 = torch.nn.Linear(in_features=self.LINEAR_ONE_IN, out_features=self.LINEAR_TWO_IN)
        self.fc2 = torch.nn.Linear(in_features=self.LINEAR_TWO_IN, out_features=2)
        self.tanh = torch.nn.Tanh()

    def _undo_transformations(self, x: torch.Tensor, n_transform: int, batch_size: int) -> torch.Tensor:
        chunks = []
        for transform in x.chunk(n_transform):
            partitions = [batch.reshape(-1) for batch in transform.chunk(batch_size)]
            chunks.append(torch.stack(partitions))
        return torch.stack(chunks).permute(1, 0, 2)
                
    def forward(self, x: torch.Tensor) -> torch.Tensor:                
        # Apply flips and rotations over level (L) dimension
        x = self.rot_and_flip(x)
        batch_size = self.rot_and_flip.metadata["batch"]
        levels = self.rot_and_flip.metadata["levels"]
        n_transforms = self.rot_and_flip.metadata["transforms"]
        
        # Bottleneck
        x = self.bottleneck(x)

        if self.debug:
            print(f"Pre bottleneck: {x.shape}")

        x = self._undo_transformations(x, n_transforms, batch_size)
        if self.debug:
            print(f"Post bottleneck: {x.shape}")
        # Linear
        x = self.fc1(x)
        x = self.tanh(x)
        x = self.fc2(x)
        
        return x.reshape((batch_size, n_transforms * 2))

In [73]:
### (Batch * Levels, Channels, Width, Heigth) 

_batch_size = 32
_levels = 5
_channels = 1

params = {
    "debug": True,
    "nconv1": 16,
    "nconv2": 32,
    "nconv3": 32,
    "nlinear1": 2560,
    "nlinear2": 128
}

_x = torch.zeros((_batch_size, _levels, _channels, 30, 30))
_model = DELIGHTModel(**params)
_out = _model.forward(_x)
assert _out.shape == torch.Size([_batch_size, 16])

Pre bottleneck: torch.Size([1280, 512])
Post bottleneck: torch.Size([32, 8, 2560])


In [75]:
import os
from enum import Enum
from typing import Union
from dataclasses import dataclass

import numpy as np
from torch.utils.data import Dataset

class CustomDatasetType(Enum):
    TRAIN = "TRAIN"
    TEST = "TEST"
    VALIDATION = "VALIDATION"
    
@dataclass
class CustomDatasetOptions:
    dataset_type: CustomDatasetType
    n_levels: int
    fold: int
    mask: bool
    object: bool

    def get_filenames(self) -> str:
        if self.dataset_type == CustomDatasetType.TRAIN:
            X = "X_train_nlevels%i_fold%i_mask%s_objects%s.npy" % (self.n_levels, self.fold, self.mask, self.object)
            y = "y_train_nlevels%i_fold%i_mask%s_objects%s.npy" % (self.n_levels, self.fold, self.mask, self.object)
        elif self.dataset_type == CustomDatasetType.TEST:
            X = "X_test_nlevels%i_mask%s_objects%s.npy" % (self.n_levels, self.mask, self.object)
            y = "y_test_nlevels%i_mask%s_objects%s.npy" % (self.n_levels, self.mask, self.object)
        else:
            X = "X_val_nlevels%i_fold%i_mask%s_objects%s.npy" % (self.n_levels, self.fold, self.mask, self.object)
            y = "y_val_nlevels%i_fold%i_mask%s_objects%s.npy" % (self.n_levels, self.fold, self.mask, self.object)
            
        return X, y
            
class CustomDataset(Dataset):    
    def __init__(self, options: CustomDatasetOptions, source: Union[str, None] = None):
        self.source = source if source is not None else "/home/fforster/SNHosts/data"
        X_path, y_path = options.get_filenames()
        self.X = torch.Tensor(np.load(os.path.join(self.source, X_path))).permute(0, 3, 1, 2)
        self.y = torch.Tensor(self.rotateY(np.load(os.path.join(self.source, y_path))))

    def rotateY(self, y: np.ndarray) -> np.ndarray:
        y90 = [-1, 1] * y[:, ::-1]
        y180 = [-1, 1] * y90[:, ::-1]
        y270 = [-1, 1] * y180[:, ::-1]
        yflip = [1, -1] * y
        yflip90 = [-1, 1] * yflip[:, ::-1]
        yflip180 = [-1, 1] * yflip90[:, ::-1]
        yflip270 = [-1, 1] * yflip180[:, ::-1]

        # return np.concatenate([
        #     y,
        #     y90,
        #     y180,
        #     y270,
        #     yflip,
        #     yflip90,
        #     yflip180,
        #     yflip270
        # ], axis=1)

        return np.concatenate([
            y,
            y,
            y,
            y,
            y,
            y,
            y,
            y
        ], axis=1)
    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx: int):
        X = self.X[idx]
        y = self.y[idx]

        if len(X.shape) == 3: # has no channel information
            levels, width, height = X.shape 
            X = X.reshape(levels, 1, width, height) # asume 1 channel information
        return X, y

In [76]:
train_opt = CustomDatasetOptions(
    dataset_type=CustomDatasetType.TRAIN,
    n_levels=5,
    fold=0,
    mask=False,
    object=True
)
test_opt = CustomDatasetOptions(
    dataset_type=CustomDatasetType.TEST,
    n_levels=5,
    fold=0,
    mask=False,
    object=True
)
val_opt = CustomDatasetOptions(
    dataset_type=CustomDatasetType.VALIDATION,
    n_levels=5,
    fold=0,
    mask=False,
    object=True
)

batch_size = 32
source = "/home/keviinplz/universidad/tesis/snhost/data"
train = CustomDataset(options=train_opt, source=source)
test = CustomDataset(options=test_opt, source=source)
val = CustomDataset(options=val_opt, source=source)

In [77]:
from torch.utils.data import DataLoader

params = {
    "debug": False,
    "nconv1": 16,
    "nconv2": 32,
    "nconv3": 32,
    "nlinear1": 2560,
    "nlinear2": 128
}

train_dl = DataLoader(train, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val, batch_size=batch_size, shuffle=True)
model = DELIGHTModel(**params)

In [78]:
%load_ext tensorboard
ts = datetime.now().strftime('%Y%m%d_%H%M%S')
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)

name = params_to_string(params)
writer = SummaryWriter('runs/delight_{}_{}'.format(name, ts))

EPOCHS = 50

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [79]:
def train_one_epoch(epoch: int, tb_writer: SummaryWriter, device: str = "cuda"):
        running_loss = 0.
        last_loss = 0.
    
        # Here, we use enumerate(training_loader) instead of
        # iter(training_loader) so that we can track the batch
        # index and do some intra-epoch reporting
        pbar = tqdm(train_dl, leave=False, position=1)
        for i, data in enumerate(pbar):
            # Every data instance is an input + label pair
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
    
            # Zero your gradients for every batch!
            optimizer.zero_grad()
    
            # Make predictions for this batch
            outputs = model(inputs)
    
            # Compute the loss and its gradients
            loss = loss_fn(outputs, labels)
            loss.backward()
    
            # Adjust learning weights
            optimizer.step()
    
            # Gather data and report
            running_loss += loss.item()
            if i % batch_size == batch_size - 1:
                last_loss = running_loss / batch_size # loss per batch
                pbar.set_description('batch {} loss: {}'.format(i + 1, last_loss))
                tb_x = epoch * len(train_dl) + i + 1
                tb_writer.add_scalar('Loss/train', last_loss, tb_x)
                running_loss = 0.
    
        return last_loss

In [80]:
%tensorboard --logdir runs

best_vloss = 1_000_000.
device = "cuda"

os.makedirs("states", exist_ok=True)

pbar = tqdm(range(EPOCHS), leave=False, position=0)

for epoch in pbar:
    pbar.set_description("Running epoch %s" % epoch)

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    model.to(device)
    avg_loss = train_one_epoch(epoch, writer, device)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(val_dl):
            vinputs, vlabels = vdata
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            
            voutputs = model(vinputs)

            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    pbar.set_description("LOSS train %s valid %s" % (avg_loss, avg_vloss), refresh=False)

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'states/model_{}_{}'.format(ts, epoch)
        torch.save(model.state_dict(), model_path)

    epoch += 1

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

  0%|          | 0/301 [00:00<?, ?it/s]

In [81]:
best_vloss, avg_loss

(tensor(84.9670, device='cuda:0'), 6.793004307895899)