In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import torch
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

In [83]:
from src.data_loader import ViTDataLoader
from src.vit_train import ViTTrainer, load_model_dir, load_model_config
import importlib
from src.vit_data import load_cell_cycle_data

# Import the `config.<config_name>` module, build the ViT model from it
# (via load_model_config), and load the cell-cycle dataset used below.
config_name = 'cell_cycle_24x_chr4'
config = importlib.import_module('config.' + config_name)
vit = load_model_config(config)
vit_data = load_cell_cycle_data()


In [84]:
# Split the data according to the config's split settings, report the split
# sizes, and wrap model + data in a trainer.
dataloader = ViTDataLoader(
    vit_data,
    split_type=config.SPLIT_TYPE,
    split_arg=config.SPLIT_ARG,
)
split_summary = dataloader.split_repr()
print(split_summary)

trainer = ViTTrainer(vit, config_name, dataloader)

Split: chrom,4; Training: 44963; Validation: 4996; Testing: 7592


In [None]:
# Work with the trainer's own device and model object from here on.
# NOTE(review): `trainer.vit` may not be the same object passed to ViTTrainer
# (e.g. moved/wrapped) — confirm in src.vit_train.
device, vit = trainer.device, trainer.vit

In [None]:
# NOTE(review): what setup() initialises is not visible here; it is run before the
# weights are restored and predictions computed below — confirm the ordering matters.
trainer.setup()

In [None]:
# Restore trained weights into the model. Mapping to CPU lets a GPU-saved
# checkpoint load on any machine; load_state_dict copies the tensors into
# vit's existing parameters.
# NOTE(review): the run directory ("complex_24x128_...") does not obviously match
# `config_name`; confirm this checkpoint was trained with the same architecture.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint_path = 'output/complex_24x128_120_20220526_cf53/model.torch'
vit.load_state_dict(
    torch.load(checkpoint_path, map_location='cpu')
)

In [None]:
# Presumably populates the predictions/losses that perf_str and plot_predictions()
# below report on — TODO confirm in ViTTrainer.
trainer.compute_predictions_losses()

In [None]:
# print() rather than a bare expression so a (presumably preformatted) string
# renders without quotes/escape sequences.
print(trainer.perf_str)

In [None]:
# Visualise the trainer's predictions (computed by compute_predictions_losses above).
trainer.plot_predictions()

In [None]:
from src.rna_plotter import load_rna_plotter
from src.orf_plotter import ORFAnnotationPlotter

# Plot helpers passed to plot_gene_prediction below (ORF annotation plotter
# first, then the RNA plotter — same construction order as before).
orf_plotter, rna_plotter = ORFAnnotationPlotter(), load_rna_plotter()


In [None]:
from src.vit_viz import plot_gene_prediction

# Plot the model's prediction for the MET3 gene on the trainer's dataset, with
# ORF and RNA tracks from the plotters created above.
# NOTE(review): the meaning/units of 120.0 are not visible here — confirm against
# plot_gene_prediction's signature and consider giving it a named constant.
# The trailing `;` suppresses the cell's Out[] repr.
plot_gene_prediction('MET3', 120.0, trainer.vit, trainer.dataloader.dataset,
                     orf_plotter=orf_plotter, rna_plotter=rna_plotter);

In [None]:
from src.vit_data import load_cd_data_24x128

# NOTE(review): this rebinds `vit_data`, which the dataloader above was built from.
# Re-running earlier cells after this one would silently use the data returned by
# load_cd_data_24x128() instead of the cell-cycle data; consider a distinct name
# (e.g. `cd_vit_data`) once no later cell relies on `vit_data` meaning this.
vit_data  = load_cd_data_24x128()

In [None]:
# Loss log from the same run directory as the checkpoint loaded earlier.
loss_csv_path = 'output/complex_24x128_120_20220526_cf53/loss.csv'
loss_df = pd.read_csv(loss_csv_path)
loss_df.head()

In [None]:
def plot_loss_columns(frame, columns, title):
    """Plot each column in `columns` from `frame` against `frame['epoch']` on one figure.

    Each line is labelled with its column name so overlaid curves can be told apart.

    Args:
        frame: DataFrame with an `epoch` column plus the requested loss columns.
        columns: Column names to plot, in drawing order (controls the colour cycle).
        title: Figure title.
    """
    fig, ax = plt.subplots()
    for column in columns:
        ax.plot(frame['epoch'], frame[column], label=column)
    ax.set(xlabel='epoch', ylabel='loss', title=title)
    ax.legend()
    plt.show()


# NOTE(review): what distinguishes the `debug_*` columns from train_loss /
# validation_loss is not visible here — worth a sentence of prose next to the plots.
plot_loss_columns(loss_df, ['train_loss', 'debug_train'], 'Training loss')
plot_loss_columns(loss_df, ['debug_valid', 'validation_loss'], 'Validation loss')