In [None]:
# Imports
# !pip install torchinfo
import torch
import torch.nn as nn
import torch.utils.data as Data
from torchvision import datasets, transforms
from torchinfo import summary
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import time
import TrainerVTS_V07C1 as TP
import DataSetting_v2 as DS

In [None]:
%%html
<style>
/* Map the Jupyter widget colour and font-size variables onto the VS Code editor theme. */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}

/* Make ipywidget output backgrounds transparent so the editor background shows through. */
.cell-output-ipywidget-background { background-color: transparent !important; }
</style>

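### Loader

Takes `T01`/`T02` form the train/validation pool (split by `DataSplitter`); takes `T03`/`T04` are held out as the test set.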

In [None]:
# Experiment configuration: every value a reader might tune for a run lives here.
gpu = 5              # CUDA device index, passed to torch.cuda.set_device and to the trainers (cuda=...)
date = '240522'      # date tag embedded in the trainer "notion" strings
run = '30'           # dataset variant used in the data path and notions; other values used: '100', '300', '900'
exp = 'Prop-Center'  # experiment label appended to the teacher's notion

# Seed the global RNGs so repeated runs start from the same random state.
# NOTE(review): this only makes splits / weight init repeatable if the project modules
# (DataSetting_v2, TrainerVTS_V07C1) draw from these global generators; cuDNN kernels may
# still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

In [None]:
# Location of the prepared dataset for the selected run, and which takes go where.
datadir = f'../dataset/0509/make19_{run}-finished/'
train_takes = {'T01', 'T02'}  # training + validation pool
test_takes = {'T03', 'T04'}   # held-out test takes

data = DS.ModalityLoader(data_dir=datadir, mmap_mode='r')

# Tuples evaluate left to right, so the call order matches a statement-per-line layout.
train, test = data.profiling(train_takes), data.profiling(test_takes)
train_set, test_set = DS.MyDataset('tv', train), DS.MyDataset('test', test)

# The train/validation pool is split into two loaders; the test set gets its own loader.
train_loader, valid_loader = DS.DataSplitter(train_set).split_loader()
test_loader = DS.DataSplitter(test_set, 1).gen_loader()

### Teacher

Train the image encoder/decoder pair (`latent_dim=16`) with `TeacherTrainer`.

In [None]:
torch.cuda.set_device(gpu)

# Teacher networks: an image encoder and decoder sharing one latent size.
teacher_latent_dim = 16
imgencoder = TP.ImageEncoder(latent_dim=teacher_latent_dim)
imgdecoder = TP.ImageDecoder(latent_dim=teacher_latent_dim)

# Summed MSE reconstruction loss, masked, weighted by beta=0.5 inside the trainer.
T_trainer = TP.TeacherTrainer(
    beta=0.5,
    mask=True,
    recon_lossfunc=nn.MSELoss(reduction='sum'),
    name='Teacher',
    networks=[imgencoder, imgdecoder],
    lr=1e-4,
    epochs=10,
    cuda=gpu,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    notion=f"{date}_{run}_{exp}",
)

In [None]:
### Scheduler
%matplotlib inline
T_trained = T_trainer.schedule()

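### Student

Train the CSI encoder and center decoder. The image encoder and decoder are restored from a saved teacher checkpoint and passed in as evaluation modules.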

#### Train

In [None]:
torch.cuda.set_device(gpu)

# Student networks: CSI encoder + center decoder, plus the teacher's image encoder/decoder
# restored from a saved checkpoint.
csiencoder = TP.CSIEncoder(lstm_steps=225)
centerdecoder = TP.CenterDecoder()
imgencoder = TP.ImageEncoder(latent_dim=16)
imgdecoder = TP.ImageDecoder(latent_dim=16)

# NOTE(review): these checkpoints are pinned to run '30' (saved 240512, IMG*V07D1 @ epoch 167)
# and do not follow the `run` / `date` config values; this notebook imports TrainerVTS_V07C1,
# so confirm the V07D1 weights match these architectures.
teacher_ckpt_dir = "../saved/240512_30D"
imgencoder_ckpt = f"{teacher_ckpt_dir}/240512_30_Teacher_IMGENV07D1@ep167.pth"
imgdecoder_ckpt = f"{teacher_ckpt_dir}/240512_30_Teacher_IMGDEV07D1@ep167.pth"

# By default torch.load restores tensors onto the device they were saved from, which fails
# if that GPU is absent. Loading to CPU is safe because load_state_dict copies the values
# into the modules' own parameters.
imgencoder.load_state_dict(torch.load(imgencoder_ckpt, map_location='cpu'))
imgdecoder.load_state_dict(torch.load(imgdecoder_ckpt, map_location='cpu'))

S_trainer = TP.StudentTrainer(name='Student', mask=True,
                              networks=[csiencoder, centerdecoder, imgencoder, imgdecoder],
                              lr=1e-4, epochs=10, cuda=gpu,
                              notion=f"{date}_{run}C",
                              train_loader=train_loader, valid_loader=valid_loader, test_loader=test_loader)

In [None]:
### Scheduler
%matplotlib inline
S_trained = S_trainer.train(autosave=True, notion=f"{date}_{run}C", train_module={'csien', 'ctrde'}, eval_module={'imgen', 'imgde'})
S_trainer.plot_train_loss(autosave=True, notion=f"{date}_{run}C")

In [None]:
# Evaluate the student on the training loader, then plot select_num=8 selected results.
# NOTE(review): the test-split cell autosaves with this same notion; if plot_test names its
# files only by `notion`, one set of figures may overwrite the other — confirm.
eval_split = 'train'
S_trainer.test(loader=eval_split)
S_trainer.plot_test(select_num=8, autosave=True, notion=f"{date}_{run}C")

In [None]:
# Evaluate the student on the held-out test loader, then plot select_num=8 selected results.
# NOTE(review): the train-split cell autosaves with this same notion; if plot_test names its
# files only by `notion`, one set of figures may overwrite the other — confirm.
eval_split = 'test'
S_trainer.test(loader=eval_split)
S_trainer.plot_test(select_num=8, autosave=True, notion=f"{date}_{run}C")

In [None]:
# Persist the student's recorded 'pred' results under the student notion.
# NOTE(review): presumably these come from the most recent S_trainer.test call (the test
# loader, when the notebook is run top to bottom) — confirm.
student_notion = f"{date}_{run}C"
S_trainer.loss.save('pred', student_notion)