In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import importlib
import math

from tqdm.auto import tqdm as tqdm_auto
from tqdm.notebook import tqdm



import torch
import torch.cuda
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim
import torch.nn.functional as F 
from IPython.display import clear_output

In [2]:
import logging
import io

# In-memory logging: records sent to 'basic_logger' are written to a StringIO
# buffer so they can be printed from a later cell.
logger = logging.getLogger('basic_logger')
logger.setLevel(logging.DEBUG)

# logging.getLogger returns the same object on every call, so re-running this
# cell would otherwise stack one more handler per run and split the records
# between stale buffers. Detach whatever was attached before.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)

log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.DEBUG)
logger.addHandler(ch)

In [3]:
# Project-local dataset module, aliased as `dataset`.
import dataset_regression as dataset
# Re-execute the already-imported module so that edits made to it since the
# kernel started are picked up without restarting the kernel.
importlib.reload(dataset)

# Project-local model module (note: not reloaded here, unlike `dataset`).
import legnet_difgenerator

In [16]:
# ---- Run configuration ----
N = 'full'  # load full dataset
mut = (0, 400)  # (lower, upper) limits on the number of mutations
epochs = 300
batch_size = 526
num_workers = 16
lr = 0.001

# Prefer the second GPU when the machine has one, otherwise fall back to the
# first GPU and finally to the CPU, so the notebook still runs on hosts that
# do not expose a "cuda:1" device.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

cell_type_filter = 'c17'

In [17]:
# Location of the source table, relative to the notebook's working directory.
PATH_FROM = '../../../data/UTR3_zinb_norm_singleref_2023-05-23.csv'

# Read the whole CSV into a DataFrame.
df = pd.read_csv(filepath_or_buffer=PATH_FROM)

# Cell output: the distinct values found in the `cell_type` column.
df['cell_type'].unique()

array(['c1', 'c13', 'c17', 'c2', 'c4', 'c6'], dtype=object)

In [18]:
# Keep only the rows of the selected cell type and renumber the index from 0.
cell_type_mask = df['cell_type'] == cell_type_filter
df = df.loc[cell_type_mask].reset_index(drop=True)

# mass_center = sum(k * column_k) / sum(column_k) over the columns '1'..'4'.
# The terms are accumulated left to right, starting from column '1'.
weighted_sum = df['1'] * 1
for bin_index in (2, 3, 4):
    weighted_sum = weighted_sum + df[str(bin_index)] * bin_index
scores = weighted_sum / df[['1', '2', '3', '4']].sum(axis=1)
df['mass_center'] = scores

In [19]:
# Wrap the filtered frame in the project dataset class; `mut` is forwarded as
# the `limits` argument.
# NOTE(review): presumably `limits` bounds the number of mutations applied per
# sample - confirm against dataset_regression.PromotersData.
my_df = dataset.PromotersData(df, limits=mut)
# Cell output: the dataset's `.data` attribute.
my_df.data

Unnamed: 0,seq,cell_type,replicate,1,2,3,4,fold,mass_center
0,TGCAGTTTTGACCTCCCAGGCTCAAGCGATCCTCCTGCCTCAGCCT...,c17,1,21.945857,36.076924,62.723318,25.114635,val,2.623929
1,TGCAGTTTTGACCTCCCAGGCTCAAGCGATCCTCCTGCCTCAGCCT...,c17,2,19.848040,34.344834,35.704787,42.941895,val,2.765890
2,ATCAAAAAGCAGGCCAGATTCTAATCAAAATCAGGTAAATTTTAAT...,c17,1,24.996422,38.172301,39.674662,59.162970,train,2.820981
3,ATCAAAAAGCAGGCCAGATTCTAATCAAAATCAGGTAAATTTTAAT...,c17,2,28.542939,24.460446,42.991478,53.016374,train,2.808538
4,ATTTTAGTTTGCCCAAATAATATCTTGAAAATGCTCTGAATTTTAC...,c17,1,7.785771,4.782926,12.253139,28.715463,val,3.156171
...,...,...,...,...,...,...,...,...,...
56847,TGTGCTTCCTAAGAGTACAAACCTGAGCATATGTCCAGGCTTGCAA...,c17,2,23.535406,25.189248,31.332772,45.950563,train,2.791208
56848,TAGGTGGTGATCTTAAATGGGTGAGATGGAACGAGAGCACACATTA...,c17,1,19.259539,30.975137,20.634468,16.226516,train,2.388400
56849,TAGGTGGTGATCTTAAATGGGTGAGATGGAACGAGAGCACACATTA...,c17,2,18.755488,31.065221,28.372554,14.724238,train,2.420433
56850,AGGAGGCAACTGTGGCATTGCTTCCTTAACCAGCTCATGGTGTGTG...,c17,1,12.384384,33.298272,34.982940,33.136733,train,2.780933


In [21]:
def initialize_weights(m):
    """Initialise one module in place; intended for ``model.apply(...)``.

    * ``nn.Conv1d``: weights ~ N(0, sqrt(2 / (kernel_size * out_channels))),
      bias (if present) set to 0.
    * ``nn.BatchNorm1d``: weight set to 1 and bias set to 0 when the layer has
      affine parameters; layers built with ``affine=False`` are left alone.
    * ``nn.Linear``: weights ~ N(0, 0.001), bias (if present) set to 0.

    Any other module type is ignored.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, nn.Conv1d):
        fan_out = m.kernel_size[0] * m.out_channels
        nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2 / fan_out))
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm1d):
        # With affine=False both parameters are None, so each one is guarded.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=0.001)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
            

In [9]:
# Build the model from the project module with the hyper-parameters used for
# this run, then move it to the configured device.
generator = legnet_difgenerator.LegNet_diffusion(
    240,
    ks=7,
    block_sizes=[256, 128, 128, 64, 64, 64, 64],
    final_ch=4,
)
generator = generator.to(device)

# Apply the custom initialisation to every sub-module, then clear the cell's
# output area.
generator.apply(initialize_weights)
clear_output()


In [11]:
def _seed_train_worker(worker_id):
    """Seed NumPy in a training worker from that worker's torch seed.

    Inside a DataLoader worker ``torch.initial_seed()`` is distinct per worker
    and is redrawn every time the workers are started, so NumPy-based
    randomness differs between workers and between epochs.
    """
    np.random.seed(torch.initial_seed() % 2**32)


def _seed_eval_worker(worker_id):
    """Seed NumPy in a validation worker with its worker id.

    The seed is fixed per worker, so each worker starts from the same NumPy
    stream on every validation pass.
    """
    np.random.seed(worker_id)


# Seeded 80/20 split so the same samples land in the validation set on every
# run (e.g. when evaluating a checkpoint in a fresh kernel).
# NOTE(review): the frame has a `fold` column and several replicate rows per
# sequence; a random row-level split may put replicates of one sequence on
# both sides - confirm this is intended.
train_set, val_set = torch.utils.data.random_split(
    my_df,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(0),
)

dl_train = DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    worker_init_fn=_seed_train_worker,
)
dl_test = DataLoader(
    val_set,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    worker_init_fn=_seed_eval_worker,
)

optimizer = torch.optim.AdamW(generator.parameters(), lr=lr)
seq_criterion = nn.CrossEntropyLoss()
nucl_criterion = nn.KLDivLoss(reduction='batchmean')
score_criterion = nn.MSELoss()


In [12]:
class Trainer:
    """Training / validation driver with a sequence loss and an auxiliary KL loss.

    For every batch ``(target_seq, mutated_seq, _)`` two losses are computed:

    * ``seq_criterion(logits, target_seq)`` on the raw model output;
    * ``nucl_criterion(log(p_mean), target_nucl)`` where ``p_mean`` is the
      softmax of the logits over dim 1, summed over dim 2 and divided by the
      last dimension of ``target_seq``, and ``target_nucl`` is read from the
      model input at ``mutated_seq[:, 6:, 1]``.

    The optimised objective is ``loss_seq + num * loss_nucl``.

    Args:
        model: network to train; called as ``model(mutated_seq)``.
        train_dataloader: iterable of training batches.
        test_dataloader: iterable of validation batches.
        seq_criterion: loss applied to ``(logits, target_seq)``.
        nucl_criterion: loss applied to ``(log(p_mean), target_nucl)``.
        optimizer: optimizer already bound to the model parameters.
        epochs: number of epochs run by :meth:`training`.
        cell_type_filter: label used only in the output directory name.
        batch_size: stored on the instance; not used by the methods below.
        batch_per_epoch: stored on the instance; not used by the methods below.
        device: device the batches are moved to.
        num: weight of the KL term in the total loss; also part of the
            output directory name.
    """

    def __init__(self,
            model: torch.nn.Module,
            train_dataloader: torch.utils.data.DataLoader ,
            test_dataloader: torch.utils.data.DataLoader ,
            seq_criterion: torch.nn.CrossEntropyLoss,
            nucl_criterion: torch.nn.KLDivLoss,
            optimizer: torch.optim.Optimizer,
            epochs: int,
            cell_type_filter,
            batch_size: int = 1024,
            batch_per_epoch: int = 1000,
            device = torch.device("cuda:0"),
            num:int = 1
            ):
        self.optimizer = optimizer
        self.seq_criterion = seq_criterion
        self.nucl_criterion = nucl_criterion
        self.model = model
        self.train_dl = train_dataloader
        self.test_dl = test_dataloader
        self.epochs = epochs
        self.batch_per_epoch = batch_per_epoch
        self.device = device
        self.batch_size = batch_size
        # Per-epoch mean of the KL loss on the training / validation data.
        self.mean_nuc_train = []
        self.mean_nuc_val = []
        self.score = []
        self.cell_type_filter = cell_type_filter
        self.num = num
        # Most recent figure created by `plotter`; closed on the next call.
        self._last_fig = None

    def _batch_losses(self, target_seq, mutated_seq):
        """Run the model on one batch and return ``(loss_seq, loss_nucl)``.

        Both inputs must already be float tensors on ``self.device``.
        Assumes the model output is laid out as (batch, channels, length):
        softmax is taken over dim 1 and averaged over dim 2 - TODO confirm
        against the model definition.
        """
        seq_len = target_seq.shape[-1]
        target_nucl = mutated_seq[:, 6:, 1]

        logits = self.model(mutated_seq)
        # The sequence criterion receives raw logits (CrossEntropyLoss applies
        # log-softmax itself), identically in training and validation.
        loss_seq = self.seq_criterion(logits, target_seq)

        # Mean predicted distribution over positions. Kept attached to the
        # graph so that the KL term contributes gradients during training.
        pred_seq_nucl = torch.softmax(logits, dim=1).sum(dim=2) / seq_len
        loss_nucl = self.nucl_criterion(torch.log(pred_seq_nucl), target_nucl)
        return loss_seq, loss_nucl

    def train(self, epoch):
        """Run one training epoch.

        Appends the epoch's mean KL loss to ``self.mean_nuc_train``.

        Args:
            epoch: epoch index, used only in the progress message.

        Returns:
            Mean sequence loss over the batches of ``self.train_dl``.
        """
        print(f'start training, epoch = {epoch}')
        self.model.train()
        seq_losses = []
        nucl_losses = []
        for data in tqdm(self.train_dl, mininterval=60):
            target_seq, mutated_seq, _ = data
            target_seq = target_seq.float().to(self.device)
            mutated_seq = mutated_seq.float().to(self.device)

            loss_seq, loss_nucl = self._batch_losses(target_seq, mutated_seq)
            seq_losses.append(loss_seq.item())
            nucl_losses.append(loss_nucl.item())

            total_loss = loss_seq + loss_nucl * self.num

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

        self.mean_nuc_train.append(np.mean(nucl_losses))
        return np.mean(seq_losses)

    def validate(self, epoch):
        """Evaluate the model on ``self.test_dl`` without gradient tracking.

        Appends the epoch's mean KL loss to ``self.mean_nuc_val``.

        Args:
            epoch: epoch index, used only in the progress message.

        Returns:
            Mean sequence loss over the batches of ``self.test_dl``.
        """
        print(f'start validating, epoch = {epoch}')
        self.model.eval()
        seq_losses = []
        nucl_losses = []
        with torch.no_grad():
            for data in tqdm(self.test_dl, mininterval=60):
                target_seq_val, mutated_seq_val, _ = data
                mutated_seq_val = mutated_seq_val.float().to(self.device)
                target_seq_val = target_seq_val.float().to(self.device)

                loss_seq, loss_nucl = self._batch_losses(target_seq_val, mutated_seq_val)
                seq_losses.append(loss_seq.item())
                nucl_losses.append(loss_nucl.item())

        self.mean_nuc_val.append(np.mean(nucl_losses))
        return np.mean(seq_losses)

    def training(self):
        """Run the full train/validate loop for ``self.epochs`` epochs.

        Creates the output directory, and after every epoch writes the loss
        plot, the loss arrays and a checkpoint into it.

        Returns:
            ``(train_losses, test_losses)``: per-epoch mean sequence losses.
        """
        self.save_dir = f"./saved_model/utr3/model_epochs_{self.epochs}_cell_type_{self.cell_type_filter}_{self.num}/"
        os.makedirs(self.save_dir, exist_ok=True)
        train_losses = []
        test_losses = []
        for epoch in tqdm(range(self.epochs)):
            tr_loss = self.train(epoch)
            train_losses.append(tr_loss)

            test_loss = self.validate(epoch)
            test_losses.append(test_loss)

            self.plotter(train_losses, test_losses, epoch)
            self.save_model(epoch, train_losses)
        return train_losses, test_losses

    def plotter(self, loss_train, loss_val, epoch):
        """Plot the loss curves, save the figure and dump the loss arrays.

        Requires ``self.save_dir`` to be set (done by :meth:`training`).

        Args:
            loss_train: per-epoch training sequence losses so far.
            loss_val: per-epoch validation sequence losses so far.
            epoch: current epoch index, used in the title and file name.
        """
        # Only the most recent figure is kept open; without this, one figure
        # per epoch would accumulate in pyplot's registry.
        previous_fig = getattr(self, "_last_fig", None)
        if previous_fig is not None:
            plt.close(previous_fig)

        fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(5, 1, figsize=(7, 7))

        ax1.plot(loss_train, color='red')
        ax3.plot(loss_train, color='red')
        ax2.plot(loss_val, color='blue')
        ax3.plot(loss_val, color='blue')
        ax4.plot(self.mean_nuc_train, color='black')
        ax5.plot(self.mean_nuc_val, color='black')
        ax1.grid(axis='x')
        ax2.grid(axis='x')
        ax3.grid(axis='x')
        ax2.set_xlabel('Epoch')
        ax1.set_ylabel('Train Loss')
        ax3.set_ylabel('Train and val Loss')
        ax2.set_ylabel('Val Loss')
        ax4.set_ylabel('Train KL Loss')
        ax5.set_ylabel('Val KL Loss')

        suptitle_string = f'epoch={epoch}'
        fig.suptitle(suptitle_string, y=1.05, fontsize=10)

        pic_test_name = os.path.join(self.save_dir, f"lossestrainandtest_epoch={epoch}.png")
        fig.tight_layout()
        fig.savefig(pic_test_name)
        fig.show()
        self._last_fig = fig

        # Loss histories are written next to the figures and checkpoints.
        np.save(os.path.join(self.save_dir, 'train_loss.npy'), np.array(loss_train))
        np.save(os.path.join(self.save_dir, 'test_loss.npy'), np.array(loss_val))

    def save_model(self, epoch, losseshist):
        """Write a checkpoint for ``epoch`` into ``self.save_dir``.

        The checkpoint holds the epoch index, the model and optimizer state
        dicts, and ``losseshist`` under the key ``'loss'``.
        """
        PATH = os.path.join(self.save_dir, f"model_{epoch}.pth")

        torch.save({
            'epoch' : epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': losseshist
            }, PATH)

        print(f'---------------  SAVED MODEL {PATH}-------------------')

In [None]:
# Weight of the KL term; forwarded to the trainer as `num`.
KL_num = 1

train = Trainer(
    model=generator,
    optimizer=optimizer,
    seq_criterion=seq_criterion,
    nucl_criterion=nucl_criterion,
    train_dataloader=dl_train,
    test_dataloader=dl_test,
    epochs=epochs,
    batch_size=batch_size,
    cell_type_filter=cell_type_filter,
    device=device,
    num=KL_num,
)

# Any failure is recorded (message + traceback) through `logger` instead of
# being raised, so the cell itself always finishes.
try:
    train_losses, test_losses = train.training()
except Exception:
    logger.exception("Training failed")

-------------------------------------------------------


In [14]:
# Read back everything written so far to the in-memory log buffer and print
# it, which surfaces any exception recorded through `logger`.
log_contents = log_capture_string.getvalue()
print(log_contents)


