In [None]:
# Возможно понадобится
from IPython.display import clear_output
!pip install devito
clear_output

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import glob
import pickle
import numpy as np
from matplotlib import pyplot as plt
import random
from tqdm import tqdm
# import devito

In [None]:
# Ask PyTorch to release the unused GPU memory it is caching before the run starts.
torch.cuda.empty_cache()

In [None]:
# Shared sizes used by the dataset and the model cells below.
NUM_CARDS = 3  # passed as `num_cards` to GeoDataset and as `num_channels` to the network
NUM_DETECTORS = 60  # passed as `num_detectors` to GeoDataset
NUM_TIMESTEPS = 512  # passed as `num_timesteps` to GeoDataset

In [None]:
from torchview import draw_graph
import graphviz

# Have graphviz objects render as PNG when they are displayed in the notebook.
graphviz.set_jupyter_format('png')

# Plotting

In [None]:
def plot_element(xs, y):
    """Draw every card of one sample side by side and overlay the target curve.

    Each card is transposed for display and every column of the transposed
    card is scaled to unit peak amplitude before plotting.

    Args:
        xs: array-like whose first axis enumerates the cards; each ``xs[i]`` is
            a 2-D array. The input is only read, never modified.
        y: target values, drawn in red on the middle panel.
    """
    xs = np.asarray(xs)
    # squeeze=False keeps the axes array 2-D, so a single-card sample can be indexed too.
    fig, ax = plt.subplots(1, xs.shape[0], figsize=(25, 5), squeeze=False)
    ax = ax[0]
    for i in range(xs.shape[0]):
        # Work on a float copy: `xs[i].T` is a view, and normalising it in place
        # would silently rescale the caller's array (and the tensor it may share memory with).
        buff = np.array(xs[i].T, dtype=float)
        if buff.size:
            peaks = np.max(np.abs(buff), axis=0, keepdims=True)
            # All-zero columns are left untouched instead of being divided by zero.
            np.divide(buff, peaks, out=buff, where=peaks > 0)
        ax[i].imshow(buff, cmap='gray', vmin=-0.05, vmax=0.05)
        ax[i].axis('off')
        ax[i].set_aspect('auto')
    ax[xs.shape[0]//2].plot(y, c='r')
    plt.show()

# Dataset

In [None]:
class GeoDataset(Dataset):
    """Windows of neighbouring "cards" with a per-detector target for the centre card.

    Every ``*.pickle`` file under ``path`` is expected to hold a dict with the keys
    ``"data"`` (a sequence of 2-D arrays indexed as ``[time, detector]``) and
    ``"targets"`` (one per-detector array for each entry of ``"data"``).

    Item ``idx`` is built from the records ``idx - win .. idx + win`` of the flat
    index (``win = num_cards // 2``). Window slots that fall outside the dataset or
    belong to a different file than the centre record are left as zeros.
    The windowing is written for an odd ``num_cards``.

    If ``path_noise`` is given, a random window of ``*.npz`` noise records (key
    ``"noise"``) is added to every sample, scaled by a random factor between
    ``min_noise_lvl`` and ``max_noise_lvl`` times the RMS of the centre record.

    NOTE(review): files are read with ``pickle.load``, which can execute arbitrary
    code -- only point ``path`` at trusted data.
    """

    # Placeholder written into window slots that have no record behind them.
    _PAD = -999

    def __init__(self, num_cards, num_detectors, num_timesteps, path, path_noise=None,
                 min_noise_lvl=0.2,
                 max_noise_lvl=1.5,
                ):
        """Index all element files under ``path`` (and noise files under ``path_noise``).

        Args:
            num_cards: number of cards (channels) per sample.
            num_detectors: number of detector positions each card is resampled to.
            num_timesteps: number of time samples each card is resampled to.
            path: directory containing the ``*.pickle`` element files.
            path_noise: optional directory containing ``*.npz`` noise files.
            min_noise_lvl: lower bound of the random noise factor.
            max_noise_lvl: upper bound of the random noise factor.
        """
        self.num_cards = num_cards
        self.win = self.num_cards//2
        self.num_detectors = num_detectors
        self.num_timesteps = num_timesteps
        self.path = path
        self.length = 0
        self.idx_map = []
        self.elem_rms = []
        # glob returns files in arbitrary order; sorting keeps index -> sample stable between runs.
        self.coll = sorted(glob.glob(self.path + "/*.pickle"))
        self.map_dataset_params()
        self.path_noise = path_noise
        self.coll_noise = None
        self.length_noise = 0
        self.max_noise_lvl = max_noise_lvl
        self.min_noise_lvl = min_noise_lvl
        if self.path_noise is not None:
            self.coll_noise = sorted(glob.glob(self.path_noise + "/*.npz"))
            self.length_noise = len(self.coll_noise)

    @staticmethod
    def _load_pickle(file_path):
        """Read one element file; the handle is closed even if unpickling fails."""
        with open(file_path, 'rb') as handle:
            return pickle.load(handle)

    @staticmethod
    def _rms(record):
        """Root-mean-square amplitude of one record."""
        return np.sqrt(np.mean(record**2))

    def map_dataset_params(self):
        """Build the flat index ``[file_idx, record_idx]`` and the per-record RMS list."""
        # Start from scratch so that calling this twice does not duplicate entries.
        self.idx_map = []
        self.elem_rms = []
        for j, elem in enumerate(tqdm(self.coll)):
            a_dict = self._load_pickle(elem)
            for i, record in enumerate(a_dict["data"]):
                self.idx_map.append([j, i])
                self.elem_rms.append(self._rms(record))
        self.length = len(self.idx_map)

    def __getitem__(self, idx):
        """Return ``(sample, target)`` for the window centred on record ``idx``.

        Returns:
            sample: float32 tensor of shape (num_cards, num_detectors, num_timesteps).
            target: float32 tensor of shape (num_detectors, 1).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = int(idx)
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"index {idx} is out of range for a dataset of {self.length} elements")

        # Window of [file_idx, record_idx] pairs; slots past either end keep the placeholder.
        window = np.full((2 * self.win + 1, 2), self._PAD, dtype=int)
        for row, pos in enumerate(range(idx - self.win, idx + self.win + 1)):
            if 0 <= pos < self.length:
                window[row] = self.idx_map[pos]
        elems = window[:, 0]
        cards = window[:, 1]
        # Only neighbours stored in the same file as the centre record are used.
        mask = elems == elems[self.win]

        a_dict = self._load_pickle(self.coll[elems[self.win]])
        xs = [a_dict["data"][card] for card in cards[mask]]
        y = a_dict["targets"][cards[self.win]]
        sample, target = self.transform(xs, y, mask)

        if self.path_noise is not None:
            level = self.min_noise_lvl + random.random() * (self.max_noise_lvl - self.min_noise_lvl)
            scalar = float(level * self.elem_rms[idx])
            return sample + scalar * self.pick_random_noise(), target
        return sample, target

    def pick_random_noise(self):
        """Return a random window of ``2 * win + 1`` consecutive noise records as a float32 tensor.

        Every record is scaled to unit peak amplitude and all records of the window are
        rolled by the same random shift along their first axis.
        NOTE(review): the records are assumed to already have the
        (num_detectors, num_timesteps) layout of a card -- confirm against the noise files.

        Raises:
            ValueError: if fewer noise files than the window size are available.
        """
        span = 2 * self.win + 1
        if self.length_noise < span:
            raise ValueError(
                f"need at least {span} noise files to draw a window, found {self.length_noise}"
            )
        center = random.randint(self.win, self.length_noise - 1 - self.win)
        roll = np.random.randint(0, self.num_detectors)
        noises = []
        for i in range(center - self.win, center + self.win + 1):
            # NpzFile keeps the archive open until closed, hence the context manager.
            with np.load(self.coll_noise[i]) as archive:
                buff = archive["noise"].squeeze()
            peak = np.max(np.abs(buff))
            if peak > 0:  # an all-zero record stays zero instead of turning into NaN
                buff = buff / peak
            buff = np.roll(buff, roll, axis=0)
            noises.append(buff)
        return torch.tensor(np.array(noises), dtype=torch.float32)

    def transform(self, xs, y, mask):
        """Resample the records to the fixed grid and place them in their window slots.

        Args:
            xs: the records selected by ``mask``, in window order, each indexed ``[time, detector]``.
            y: per-detector target of the centre record, in time-sample units of that record.
            mask: boolean array over the window; True marks slots that have a record in ``xs``.

        Returns:
            ``(sample, target)`` tensors as described in ``__getitem__``.
        """
        mask = np.asarray(mask)
        xr = np.zeros((self.num_cards, self.num_detectors, self.num_timesteps))
        indicies = np.flatnonzero(mask)
        # Position of the centre record inside `xs` = number of selected slots before it.
        center = xs[int(np.count_nonzero(mask[:self.win]))]
        sub_d = np.round(np.linspace(0, center.shape[1]-1, num=self.num_detectors)).astype(int)
        sub_t = np.round(np.linspace(0, center.shape[0]-1, num=self.num_timesteps)).astype(int)
        for k, i in enumerate(indicies):
            buff = xs[k].T
            buff = buff[:, sub_t]
            buff = buff[sub_d, :]
            xr[i, :, :] = buff

        # Rescale the target by the same factor the time axis was resampled with.
        y = y*(self.num_timesteps/center.shape[0])
        y = y[sub_d]
        return torch.tensor(xr, dtype=torch.float32), torch.tensor(y, dtype=torch.float32).unsqueeze(1)

    def map_rms_lvl(self):
        """Recompute the per-record RMS list from the element files."""
        rms = []
        for elem in self.coll:
            a_dict = self._load_pickle(elem)
            rms.extend(self._rms(record) for record in a_dict["data"])
        # Replace rather than append, so the list stays aligned with `idx_map`.
        self.elem_rms = rms

    def __len__(self):
        """Number of records across all element files."""
        return self.length

In [None]:
# Build the dataset from local folders. Both paths are machine-specific absolute
# paths and have to be adjusted when the notebook runs elsewhere.
dataset = GeoDataset(NUM_CARDS, NUM_DETECTORS, NUM_TIMESTEPS,
                     path="/home/andrey/Elastic/Elements",
                     path_noise="/home/andrey/Noise Dataset",
                    )

In [None]:
# Number of records indexed across all element files.
len(dataset)

In [None]:
# Shuffled mini-batches of 8 samples; the same loader feeds the preview and the training below.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [None]:
# Preview one shuffled batch: draw every sample together with its target curve.
imgs, targets = next(iter(dataloader))
for i, (sample, target) in enumerate(zip(imgs, targets)):
    xs = sample.numpy()
    y = target.numpy()
    plot_element(xs, y)

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

# Model

In [None]:
class ResNeXtBlock(nn.Module):
    """Residual block: 7x7 depthwise conv -> GroupNorm -> 1x1 expand, SiLU, 1x1 project.

    The input shape is preserved, and the block returns the input plus the
    output of that branch.

    Args:
        num_channels: channels of the input (and of the output).
        norm_groups: number of groups for the GroupNorm layer.
        expansion_rate: width multiplier of the hidden 1x1 convolution.
    """

    def __init__(self, num_channels, norm_groups=32, expansion_rate=4):
        super().__init__()
        # groups == channels makes the convolution depthwise; padding=3 keeps H and W for a 7x7 kernel.
        self.dw_conv = nn.Conv2d(
            in_channels=num_channels,
            out_channels=num_channels,
            kernel_size=7,
            padding=3,
            groups=num_channels,
        )
        self.group_norm = nn.GroupNorm(num_groups=norm_groups, num_channels=num_channels)
        widened = num_channels * expansion_rate
        pointwise_layers = [
            nn.Conv2d(num_channels, widened, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(widened, num_channels, kernel_size=1),
        ]
        self.feed_forward = nn.Sequential(*pointwise_layers)

    def forward(self, x):
        """Return ``x`` plus the conv -> norm -> feed-forward branch applied to ``x``."""
        branch = self.feed_forward(self.group_norm(self.dw_conv(x)))
        return x + branch

In [None]:
class ResNet1D(nn.Module):
    """Convolutional network that reduces the last input axis and predicts one value per row.

    The stem and every stage use ``(1, k)`` kernels with matching strides, so only the
    last axis of the input is downsampled; the second-to-last axis is kept and ends up
    as the sequence axis of the output.

    Args:
        model_channels: base channel width (was 256; reduced to 128 as discussed earlier).
        num_channels: channels of the input tensor.
        groups: number of groups for every GroupNorm layer.
        expansion_rate: hidden-width multiplier inside each ResNeXtBlock.
        dim_mult: channel multipliers; stage ``k`` maps the width of stage ``k - 1``
            (starting from ``model_channels``) to ``model_channels * dim_mult[k]``.
        num_blocks: number of ResNeXtBlocks in each stage.
    """

    def __init__(
        self,
        model_channels=128,
        num_channels=3,
        groups=32,
        expansion_rate=4,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
    ):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(num_channels, model_channels, kernel_size=(1, 4), stride=(1, 4)),
            nn.GroupNorm(groups, model_channels)
        )

        # tuple() lets callers pass `dim_mult` as a list as well.
        hidden_dims = [model_channels * mult for mult in (1,) + tuple(dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # zip() stops at the shorter sequence, so these are the stages that really get built.
        stages = list(zip(in_out_dims, num_blocks))
        self.resnext_blocks = nn.Sequential(*[
            nn.Sequential(
                *[ResNeXtBlock(in_dim, groups, expansion_rate) for _ in range(num_block)],
                nn.GroupNorm(groups, in_dim),
                nn.Conv2d(in_dim, out_dim, kernel_size=(1, 2), stride=(1, 2))
            ) for (in_dim, out_dim), num_block in stages
        ])

        # Size the head from the last stage that exists, not from the last entry of
        # `dim_mult`, so a shorter `num_blocks` cannot cause a shape mismatch in forward().
        final_dim = stages[-1][0][1] if stages else model_channels
        self.out_layer = nn.Linear(final_dim, 1)

    def forward(self, x):
        """Map ``(batch, num_channels, rows, cols)`` to ``(batch, rows, 1)``.

        NOTE(review): in this notebook rows/cols presumably are detectors/timesteps,
        as produced by GeoDataset -- confirm.
        """
        x = self.stem(x)
        x = self.resnext_blocks(x)
        # Average over the (downsampled) last axis, then move channels last for the linear head.
        x = x.mean(-1).transpose(-1, -2)
        x = self.out_layer(x)
        return x

In [None]:
# Stand-alone network instance for the summary / graph cells that follow.
# NOTE: the name `model` is re-bound to a GeoTrainer in the trainer setup cell.
model = ResNet1D(num_channels=NUM_CARDS)

In [None]:
from torchinfo import summary

# Summarise the network for the shape the dataset actually produces; using the
# constants keeps this cell in sync when any of the sizes is changed.
summary(model, input_size=(1, NUM_CARDS, NUM_DETECTORS, NUM_TIMESTEPS), depth=3)

In [None]:
import torchvision
from torchview import draw_graph

# Trace the network with the same input shape the dataset produces (taken from the
# constants instead of a hard-coded tuple) and show the resulting graph.
model_graph = draw_graph(
    model,
    input_size=(1, NUM_CARDS, NUM_DETECTORS, NUM_TIMESTEPS),
    expand_nested=True,
    depth=4,
    graph_dir="TD",
)
model_graph.visual_graph

# Trainer

In [None]:
class GeoTrainer(pl.LightningModule):
    """LightningModule that trains a ResNet1D with a mean-squared-error loss.

    Args:
        num_channels: number of input channels of the wrapped ResNet1D.
    """

    def __init__(self, num_channels):
        super().__init__()
        # Use the constructor argument; it used to be ignored in favour of the NUM_CARDS global.
        self.model = ResNet1D(num_channels=num_channels)
        self.loss = nn.MSELoss()

    def configure_optimizers(self):
        """Adam with a fixed learning rate; no scheduler is active."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=2e-4)
        # Disabled warm-up, kept for reference:
#         lr_scheduler = torch.optim.lr_scheduler.LinearLR(
#             optimizer, start_factor=0.0002, end_factor=1.0, total_iters=5000
#         )
        return [optimizer]

    def model_step(self, batch, stage):
        """Run one batch through the model, log ``{stage}_loss`` and return the loss."""
        img, target_timesteps = batch
        pred_timesteps = self.model(img)
        loss = self.loss(pred_timesteps, target_timesteps)
        # Log the detached tensor: calling .cpu().item() here forces a device sync every step.
        self.log(f'{stage}_loss', loss.detach())
        return loss

    def training_step(self, batch, batch_idx):
        """Training step; logs ``train_loss``."""
        return self.model_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Validation step; logs ``valid_loss``."""
        return self.model_step(batch, 'valid')

In [None]:
# Experiment name: used for the TensorBoard run and for the checkpoint sub-folder.
name = "01_augmented_test"
tb_logger = TensorBoardLogger(save_dir="./tb_logs_augmented", name=name, default_hp_metric=False)

# Keep every checkpoint (save_top_k=-1), written once per 5000 training steps.
checkpoint_callback = ModelCheckpoint(
    dirpath="./saved_models/"+name,
    filename="{step}",
    monitor="train_loss",
    mode="min",
    save_top_k=-1,
    every_n_train_steps=5000,
)
callbacks = [checkpoint_callback]

trainer = pl.Trainer(
    logger=tb_logger,
    callbacks=callbacks,
#     gpus=1,
    log_every_n_steps=5,
    max_steps=500000,
    gradient_clip_val=1.0,
    gradient_clip_algorithm="value",
    accumulate_grad_batches=4, # TODO: pass proper arguments here
)
model = GeoTrainer(NUM_CARDS)

In [None]:
# Start TensorBoard inline, pointed at the directory the logger above writes to.
%load_ext tensorboard
%tensorboard --logdir {"./tb_logs_augmented"}

In [None]:
# Train on the shuffled loader; no validation loader is passed.
trainer.fit(model, dataloader)

In [None]:
torch.cuda.empty_cache()

# `model` is the GeoTrainer wrapper here, so the saved keys carry a "model." prefix.
torch.save(model.state_dict(), "baseline_9k_steps.pth")

imgs, gt = next(iter(dataloader))

imgs.shape

print(model)

# Put the module on the same device as the batch and switch off autograd for
# inference; otherwise the forward pass can fail with a device mismatch and
# needlessly keeps the graph in memory.
model.to(device)
model.eval()
with torch.no_grad():
    y_pred = model.model(imgs.to(device))

def plot_element_gt(xs, gt, pr):
    """Draw every card of one sample and overlay ground truth (red) and prediction (blue).

    Each card is transposed for display and every column of the transposed
    card is scaled to unit peak amplitude before plotting.

    Args:
        xs: array-like whose first axis enumerates the cards; each ``xs[i]`` is
            a 2-D array. The input is only read, never modified.
        gt: ground-truth values, drawn in red on the middle panel.
        pr: predicted values, drawn in blue on the middle panel.
    """
    xs = np.asarray(xs)
    # squeeze=False keeps the axes array 2-D, so a single-card sample can be indexed too.
    fig, ax = plt.subplots(1, xs.shape[0], figsize=(25, 5), squeeze=False)
    ax = ax[0]
    for i in range(xs.shape[0]):
        # Work on a float copy: `xs[i].T` is a view, and normalising it in place
        # would silently rescale the caller's array (and the tensor it may share memory with).
        buff = np.array(xs[i].T, dtype=float)
        if buff.size:
            peaks = np.max(np.abs(buff), axis=0, keepdims=True)
            # All-zero columns are left untouched instead of being divided by zero.
            np.divide(buff, peaks, out=buff, where=peaks > 0)
        ax[i].imshow(buff, cmap='gray', vmin=-0.05, vmax=0.05)
        ax[i].axis('off')
        ax[i].set_aspect('auto')
    ax[xs.shape[0]//2].plot(gt, c='r')
    ax[xs.shape[0]//2].plot(pr, c='b')
    plt.show()

# Compare ground truth and prediction for every sample of the batch.
for i, (sample, target, prediction) in enumerate(zip(imgs, gt, y_pred)):
    xs = sample.numpy()
    y_gt = target.numpy()
    y_pr = prediction.detach().cpu().numpy()
    plot_element_gt(xs, y_gt, y_pr)