In [1]:
from typing import Callable, Tuple, List, Any

import torch

In [2]:
class DualSimpleCNN(torch.nn.Module):
    """Two-branch CNN whose branch features are merged and then classified.

    Each input goes through its own stack of conv blocks (``net1`` for ``x1``,
    ``net2`` for ``x2``). The two feature maps are merged either by channel
    concatenation or by element-wise addition, passed through ``pre_mlp_depth``
    further conv blocks (``net3``), flattened, and fed to a two-layer MLP head.

    Args:
        layers_dim: Channel/feature sizes. Consecutive pairs of
            ``layers_dim[:-2]`` define the conv blocks of each branch
            (``layers_dim[0]`` is the number of input channels), ``layers_dim[-2]``
            is the hidden width of the MLP head and ``layers_dim[-1]`` the number
            of outputs. At least three entries are required.
        wheter_concate: If True the branch outputs are concatenated along the
            channel dimension (doubling the channel count seen by ``net3``);
            otherwise they are summed. (The spelling is kept for backward
            compatibility with existing callers.)
        pre_mlp_depth: Number of conv blocks applied after the merge.
        flatten_dim: Number of features per sample after flattening the output
            of ``net3``, i.e. ``channels * height * width`` at that point. The
            default 1024 matches ``layers_dim=[3, 32, 64, 128, 256, 10]`` with
            concatenation, ``pre_mlp_depth=1`` and 32x32 inputs
            (256 channels x 2 x 2). Pass a different value for other input
            sizes or architectures.
    """

    def __init__(self, layers_dim: List[int], wheter_concate: bool = False, pre_mlp_depth: int = 1,
                 flatten_dim: int = 1024):
        super().__init__()
        # 2 -> concatenate branch outputs (channels double), 1 -> add them.
        self.scaling_factor = 2 if wheter_concate else 1

        # (in_channels, out_channels) of every per-branch block; the last two
        # entries of layers_dim belong to the MLP head and are excluded here.
        branch_dims = list(zip(layers_dim[:-3], layers_dim[1:-2]))
        self.net1 = torch.nn.ModuleList([
            self._conv_block(in_channels, out_channels) for in_channels, out_channels in branch_dims
        ])
        self.net2 = torch.nn.ModuleList([
            self._conv_block(in_channels, out_channels) for in_channels, out_channels in branch_dims
        ])

        # Blocks applied to the merged feature map keep the channel count constant.
        pre_mlp_channels = layers_dim[-3] * self.scaling_factor
        self.net3 = torch.nn.ModuleList([
            self._conv_block(pre_mlp_channels, pre_mlp_channels) for _ in range(pre_mlp_depth)
        ])

        self.final_layer = torch.nn.Sequential(torch.nn.Linear(flatten_dim, layers_dim[-2]),
                                               torch.nn.BatchNorm1d(layers_dim[-2]),
                                               torch.nn.ReLU(),
                                               torch.nn.Linear(layers_dim[-2], layers_dim[-1]))

    @staticmethod
    def _conv_block(in_channels: int, out_channels: int) -> torch.nn.Sequential:
        """Build one (Conv3x3 -> BatchNorm -> ReLU) x 2 -> MaxPool(2, 2) block."""
        return torch.nn.Sequential(torch.nn.Conv2d(in_channels, out_channels, 3, padding=1),
                                   torch.nn.BatchNorm2d(out_channels),
                                   torch.nn.ReLU(),
                                   torch.nn.Conv2d(out_channels, out_channels, 3, padding=1),
                                   torch.nn.BatchNorm2d(out_channels),
                                   torch.nn.ReLU(),
                                   torch.nn.MaxPool2d(2, 2))

    def forward(self, x1, x2):
        """Run both branches, merge their features and classify.

        Args:
            x1: Input of the first branch; its channel dimension must equal
                ``layers_dim[0]``.
            x2: Input of the second branch, same requirements as ``x1``. Both
                branch outputs must have matching shapes to be merged.

        Returns:
            Tensor whose last dimension has size ``layers_dim[-1]``.
        """
        for block in self.net1:
            x1 = block(x1)
        for block in self.net2:
            x2 = block(x2)
        # Merge strategy was fixed at construction time via `wheter_concate`.
        y = torch.cat((x1, x2), dim=1) if self.scaling_factor == 2 else x1 + x2
        for block in self.net3:
            y = block(y)
        y = y.flatten(start_dim=1)
        y = self.final_layer(y)
        return y

In [None]:
# Instantiate the dual-branch model with channel concatenation enabled and a
# single conv block after the merge.
model = DualSimpleCNN([3, 32, 64, 128, 256, 10], wheter_concate=True, pre_mlp_depth=1)
# Bare expression so the notebook displays the module structure.
model

In [None]:
# Smoke test: feed two random (2, 3, 32, 32) batches and display the output shape.
model(torch.randn(2, 3, 32, 32), torch.randn(2, 3, 32, 32)).shape

In [3]:
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, RandomAffine, RandomHorizontalFlip
from torchvision.transforms import InterpolationMode
# Per-channel statistics handed to Normalize in both pipelines below
# (presumably the CIFAR-10 channel mean/std - verify against the dataset).
mean, std = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.262)

# Training pipeline for the non-degraded input: tensor conversion, random
# translation (no rotation), random horizontal flip, then normalisation.
transform_train_proper = Compose([
    ToTensor(),
    RandomAffine(degrees=0, translate=(1/8, 1/8)),
    RandomHorizontalFlip(),
    Normalize(mean, std)
])

# Same augmentation as above, preceded by a bilinear resize to 8 and back to 32
# to degrade the image.
# NOTE(review): with an int size, torchvision's Resize is documented to match
# the shorter edge and keep the aspect ratio, so non-square inputs (such as the
# crops produced by SplitAndAugmentDataset) would probably not come back at
# their original size - confirm, and consider explicit (h, w) sizes.
transform_train_blurred = Compose([
    ToTensor(),
    Resize(8, interpolation=InterpolationMode.BILINEAR, antialias=None),
    Resize(32, interpolation=InterpolationMode.BILINEAR, antialias=None),
    RandomAffine(degrees=0, translate=(1/8, 1/8)),
    RandomHorizontalFlip(),
    Normalize(mean, std)
])

In [4]:
import math
from torch.utils.data import DataLoader, Dataset

class SplitAndAugmentDataset(Dataset):
    """Wrap an ``(image, label)`` dataset and yield a left and a right crop per sample.

    Every image is cut into two full-height crops: one anchored at the left
    edge and one anchored at the right edge. Each crop spans
    ``ceil(width * (overlap / 2 + 0.5))`` pixels, so ``overlap=0`` gives two
    halves, ``overlap=1`` gives the full image twice, and a negative value
    leaves a gap between the crops. The left crop is passed through
    ``transform1`` and the right crop through ``transform2``.

    Args:
        dataset: Indexable dataset returning ``(image, label)``. The image must
            expose ``size`` as ``(width, height)`` and a ``crop(box)`` method
            taking ``(left, upper, right, lower)`` (e.g. a PIL image).
        transform1: Callable applied to the left crop.
        transform2: Callable applied to the right crop.
        overlap: Controls the crop width as described above; must satisfy
            ``-1 < overlap <= 1``.

    Raises:
        ValueError: If ``overlap`` is outside ``(-1, 1]``, which would produce
            empty crops or crops larger than the image.
    """

    def __init__(self, dataset, transform1, transform2, overlap=0.5):
        if not -1 < overlap <= 1:
            raise ValueError(f"overlap must be in (-1, 1], got {overlap!r}")
        self.dataset = dataset
        self.transform1 = transform1
        self.transform2 = transform2
        # Fraction of the full width kept by each crop.
        self.with_overlap = overlap / 2 + 0.5

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(transform1(left_crop), transform2(right_crop), label)`` for sample ``idx``."""
        image, label = self.dataset[idx]

        width, height = image.size
        # Round up so that the two crops together always cover the whole width
        # when overlap >= 0.
        crop_width = math.ceil(width * self.with_overlap)
        left_crop = image.crop((0, 0, crop_width, height))
        right_crop = image.crop((width - crop_width, 0, width, height))

        # transform1 runs before transform2, as in the original implementation,
        # so random augmentations consume the RNG in the same order.
        image1 = self.transform1(left_crop)
        image2 = self.transform2(right_crop)

        return image1, image2, label

In [5]:
import torchvision
# CIFAR-10 training split rooted at ./data (download requested); no transform is
# attached here, the samples are transformed by SplitAndAugmentDataset instead.
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
# overlap=0.05 -> each crop keeps 0.05 / 2 + 0.5 = 52.5% of the image width.
dataset1 = SplitAndAugmentDataset(dataset, transform_train_proper, transform_train_blurred, overlap=0.05)
# Shuffled batches of 128 samples loaded by 4 worker processes.
dataloader = DataLoader(dataset1, batch_size=128, shuffle=True, num_workers=4)

Files already downloaded and verified


In [None]:
# Apply the blurred pipeline to the first raw (uncropped) sample and display the
# resulting tensor shape.
transform_train_blurred(dataset.data[0]).shape

In [None]:
# Fetch the first sample: transformed left crop, transformed right crop, label.
a, b, c = dataset1[0]

In [None]:
# Shape of the first returned crop (processed by transform_train_proper).
a.shape

In [None]:
# Shape of the second returned crop (processed by transform_train_blurred).
b.shape

In [None]:
# Draw the first sample again (the augmentations are random, so values may differ).
a, b, c = dataset1[0]

In [None]:
# Compare the shapes of both crops side by side.
a.shape, b.shape

In [None]:
# Resize operators with explicit two-element target sizes (instead of a single
# int): op1 shrinks to (32, 8), op2 restores (32, 32).
op1 = Resize((32, 8), interpolation=InterpolationMode.BILINEAR, antialias=None)
# Fixed: the size tuple was missing its opening parenthesis, which made this
# cell a SyntaxError.
op2 = Resize((32, 32), interpolation=InterpolationMode.BILINEAR, antialias=None)

In [None]:
# Push a random (3, 32, 17) tensor through op1 and then op2, and print both
# resulting shapes to check what the two Resize operators produce.
x = op1(torch.randn(3, 32, 17))
y = op2(x)
print(x.shape, y.shape)

In [None]:
# Shape produced by op1 alone for a (3, 32, 17) input.
op1(torch.randn(3, 32, 17)).shape

In [10]:
import numpy as np

In [12]:
# Convert the image of the first untransformed sample to a NumPy array and
# display its shape.
np.asarray(dataset[0][0]).shape

(32, 32, 3)

In [1]:
from src.utils.utils_trainer import manual_seed, find_paths

In [10]:
# Materialise whatever the project helper find_paths yields for the relative
# 'reports' directory.
# NOTE(review): find_paths is defined in src.utils.utils_trainer and not visible
# here - its matching rules should be confirmed.
list(find_paths('reports'))

[]

In [3]:
import os
# Show the entries of the relative 'reports' directory.
os.listdir('reports')

['deficit, sgd, dual_cifar10, dual_simple_cnn_fp_0.0_lr_0.2_wd_0.0_window_0 overlap=0.0, both enabled']

In [4]:
# Relative path of one specific run directory inside 'reports'.
path = 'reports/deficit, sgd, dual_cifar10, dual_simple_cnn_fp_0.0_lr_0.2_wd_0.0_window_0 overlap=0.0, both enabled'
# Prefix it with the current working directory to obtain an absolute path.
path = os.path.join(os.getcwd(), path)
# (uncomment to display the resolved path)
# path

In [5]:
# Query the project helper with the absolute run directory built above.
find_paths(path)

['model_step_None_window_0.pth']