In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch import nn

from sklearn.model_selection import train_test_split

from common import data

from common.training import training_loop


from common.models import resnet
from common.models import deeplab

In [3]:
# Windowing / loader configuration.
horizon = 1024      # window length passed to UnfoldedDataset (dim 1 of each batch below)
stride = 500        # step between windows; NOTE(review): < horizon, presumably overlapping windows
batch_size = 4
num_workers = 0     # 0 -> batches are loaded in the main process

# Split the run files with a fixed seed; only the 10% held-out part is loaded,
# since this notebook just probes tensor shapes.
paths = data.get_dataset_paths("../data")
train, test = train_test_split(paths, test_size=0.1, random_state=42)
dataset = data.Marconi100Dataset(test, data.Scaling.STANDARD)

dataset_test = data.UnfoldedDataset(dataset, horizon=horizon, stride=stride)
test_loader = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    # DataLoader rejects persistent_workers=True when num_workers == 0.
    persistent_workers=bool(num_workers),
)

Loading: 100%|██████████| 25/25 [00:15<00:00,  1.63it/s]


In [4]:
# Grab a single batch from the test loader; later cells feed it to the models.
batch = next(iter(test_loader))
# Shape shown below is (batch_size, horizon, 460); the last axis presumably
# equals data.NUM_FEATURES (the DeepLab output below matches it) — TODO confirm.
batch["data"].shape

torch.Size([4, 1024, 460])

In [5]:
# Classification variant: Bottleneck blocks with the RESNET50_LAYERS config,
# wrapped by ResNet with a num_classes=2 head.
classifier_backbone = resnet.ResNetFeatures(
    resnet.Bottleneck,
    resnet.RESNET50_LAYERS,
    num_features=data.NUM_FEATURES,
)
net = resnet.ResNet(classifier_backbone, num_classes=2)

# The raw batch is passed without permuting, unlike the bare ResNetFeatures
# call in the next cell. NOTE(review): presumably ResNet.forward transposes
# internally — confirm.
res = net(batch["data"])
res.shape


torch.Size([4, 2])

In [6]:
# Inspect every stage's feature map. In the output below, layer2 halves the
# length while layer3/layer4 keep it, matching replace_stride_with_dilation.
all_layers = [resnet.LAYER_1, resnet.LAYER_2, resnet.LAYER_3, resnet.LAYER_4]
net = resnet.ResNetFeatures(
    resnet.Bottleneck,
    resnet.RESNET50_LAYERS,
    return_layers=all_layers,
    replace_stride_with_dilation=[False, True, True],
    num_features=data.NUM_FEATURES,
)

# Swap the last two axes before calling the feature extractor directly
# (presumably it expects (batch, features, horizon) — TODO confirm).
channels_first = batch["data"].transpose(1, 2)
res = net(channels_first)

for layer_name, feature_map in res.items():
    print(layer_name, feature_map.shape)


layer1 torch.Size([4, 256, 256])
layer2 torch.Size([4, 512, 128])
layer3 torch.Size([4, 1024, 128])
layer4 torch.Size([4, 2048, 128])


In [17]:
# Segmentation variant: DeepLab head fed by layer1 and layer4. The
# backbone_channels values 256 and 2048 match the layer1/layer4 widths printed
# by the previous cell (presumably they must agree — confirm in deeplab).
seg_backbone = resnet.ResNetFeatures(
    resnet.Bottleneck,
    resnet.RESNET50_LAYERS,
    return_layers=[resnet.LAYER_1, resnet.LAYER_4],
    replace_stride_with_dilation=[False, True, True],
    num_features=data.NUM_FEATURES,
)
net = deeplab.DeepLabNet(
    seg_backbone,
    backbone_channels=[256, 2048],
    out_feats=data.NUM_FEATURES,
)

# NOTE(review): this cell prints nothing itself. The low/high/prj/res/feat/outs
# lines in its output come from inside DeepLabNet, probably leftover debug
# prints in common.models.deeplab; consider removing them there.
res = net(batch["data"])
res.shape


low torch.Size([4, 256, 256])
high torch.Size([4, 2048, 128])
prj torch.Size([4, 128, 256])
res torch.Size([4, 1280, 128])
feat torch.Size([4, 256, 128])
feat torch.Size([4, 256, 256])
outs torch.Size([4, 256, 256])


torch.Size([4, 1024, 460])

In [13]:
def what(n_parts=4, shape=(4, 256, 128), dim=1):
    """Sanity-check the shape produced by ``torch.cat`` on equally shaped tensors.

    Creates ``n_parts`` random tensors of size ``shape``, printing each size
    as it is created. It then concatenates them along ``dim``, prints the
    result's size, and returns the result. The defaults presumably mirror the
    branch concatenation inside DeepLabNet (cf. its ``res`` debug line,
    1280 = 5 x 256); that is an assumption.

    Args:
        n_parts: number of tensors to concatenate (must be >= 1).
        shape: size of each random tensor.
        dim: dimension to concatenate along.

    Returns:
        torch.Tensor: the concatenated tensor, so callers can assert on its
        shape instead of reading the printed output.

    Raises:
        ValueError: if ``n_parts`` < 1 (``torch.cat`` cannot concatenate an
        empty list).
    """
    if n_parts < 1:
        raise ValueError(f"n_parts must be >= 1, got {n_parts}")
    parts = []
    for _ in range(n_parts):
        parts.append(torch.rand(*shape))
        print(parts[-1].size())
    res = torch.cat(parts, dim=dim)
    print("res", res.size())
    return res

# Only the concatenated dimension grows: 4 parts x 256 channels -> 1024.
assert what().shape == (4, 4 * 256, 128)

torch.Size([4, 256, 128])
torch.Size([4, 256, 128])
torch.Size([4, 256, 128])
torch.Size([4, 256, 128])
res torch.Size([4, 1024, 128])
