In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import os
import numpy as np
import linecache #fast access to a specific file line
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from einops import rearrange, repeat
import torch.nn.functional as F
import time
import torchinfo
from pathlib import Path

In [2]:
# Report the PyTorch build, the CUDA version it was built with, and whether a GPU is visible.
for env_fact in (torch.__version__, torch.version.cuda, torch.cuda.is_available()):
    print(env_fact)

1.11.0
11.1
True


In [3]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [4]:
# Symbol -> integer id. Ids 0..19 are the 20 standard amino acids, in this order.
ALPHABET = {
    symbol: index
    for index, symbol in enumerate(
        ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
         "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]
    )
}

# Two extra symbols: '-' is the alignment gap, 'Z' is the catch-all used for
# padding and for any residue outside the 20 standard ones.
ALPHABET.update({'-': 20, 'Z': 21})

print(ALPHABET)

{'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, '-': 20, 'Z': 21}


In [23]:
class MyDataset(Dataset):
    """Index of aligned protein families stored as FASTA files (one family per file).

    Construction scans every regular file of ``data_dir`` once and records, per
    family: file name, aligned sequence length, number of sequences, and
    amino-acid frequencies (family-wide and per sequence). Families with fewer
    than two sequences are skipped because no pair can be drawn from them.

    Indexing a family draws random pairs of sequences and, for each pair, a
    random aligned position; see ``__getitem__`` for the returned tensors.

    The parser expects all sequences of a file to have the same length and to
    be wrapped at ``col_size`` (60) characters per line.

    Args:
        data_dir: directory containing the FASTA files.
        cont_size: number of context residues taken on each side of the
            sampled position.
        div: divisor used to derive the automatic number of samples per family
            (``round(n**2 * l / div)``, at least 1).
        verbose: print a progress line every 100 indexed families.
    """

    def __init__(self, data_dir, cont_size=6,div=1400000,verbose=False):
        self.col_size = 60 #number of column per file (Fasta standard)
        self.data_dir = data_dir #directory of the dataset
        self.cont_size = cont_size
        self.div = div
        self.len = 0  #number of families of sequences (1 per file)
        self.paths = {} #path of each families in the folder
        self.seq_lens = {} #length of each member of the family
        self.seq_nums = {} #number of member of the family
        self.aa_freqs = {} #frequencies of each symbol in the sequence family
        self.p_aa_freqs = {} #frequencies of each symbol in each sequence of a family

        count = 0
        for path in os.listdir(data_dir):
            temp_path = os.path.join(data_dir, path)
            if not os.path.isfile(temp_path):
                continue

            n, l, aa_freq, p_aa_freq = self._scan_family(temp_path)

            if n>1: #if this is false, we can't find pairs
                self.paths[count] = path
                self.seq_lens[count] = l
                self.seq_nums[count] = n
                self.aa_freqs[count] = aa_freq
                self.p_aa_freqs[count] = p_aa_freq
                count += 1

                if verbose and (count % 100 ==0) : print(f"seen = {count}")

        self.len = count

    def _scan_family(self, file_path):
        """Parse one FASTA file and return ``(n, l, aa_freq, p_aa_freq)``.

        ``n`` is the number of sequences, ``l`` the sequence length measured on
        the first sequence, ``aa_freq`` the L1-normalised counts of the 20
        standard amino acids over the whole file (shape ``(20,)``) and
        ``p_aa_freq`` the same per sequence (shape ``(n, 20)``).

        Raises:
            AssertionError: if the line layout is inconsistent with ``n``
                sequences of identical length wrapped at ``col_size``.
            ValueError: if a sequence line appears before any header.
        """
        n = 0 #number of sequences
        p = 0 #number of full-width lines of the first sequence
        r = 0 #length of the trailing, shorter line of the first sequence
        cpt = 0 #total number of parsed lines, used by the sanity check

        in_first_sequence = True #lengths are measured on the first sequence only
        aa_freq = torch.zeros(20)
        rows = [] #one (1, 20) count tensor per completed sequence
        prot_aa_freq = None #counts of the sequence currently being read

        with open(file_path, newline='') as f:
            # rstrip instead of [:-1]: a final line without a trailing newline
            # must not lose its last residue.
            line = f.readline().rstrip('\r\n')
            while line:
                cpt += 1
                if line[0] == '>': #header line
                    if prot_aa_freq is not None:
                        rows.append(prot_aa_freq)
                        in_first_sequence = False
                    prot_aa_freq = torch.zeros(1,20)
                    n += 1
                else: #sequence line
                    if prot_aa_freq is None:
                        raise ValueError(f"{file_path}: sequence data found before any '>' header")
                    if in_first_sequence:
                        if len(line) == self.col_size:
                            p += 1
                        else:
                            r = len(line)
                    for aa in line:
                        aa_id = ALPHABET.get(aa,21)
                        if aa_id < 20:
                            aa_freq[aa_id] += 1
                            prot_aa_freq[0][aa_id] += 1

                    assert len(line) == self.col_size or len(line) == r
                line = f.readline().rstrip('\r\n')

        if prot_aa_freq is not None:
            rows.append(prot_aa_freq)

        # An empty file yields zero rows; keep the (0, 20) shape so normalize works.
        p_aa_freq = torch.cat(rows) if rows else torch.zeros(0, 20)
        aa_freq = F.normalize(aa_freq,dim=0,p=1)
        p_aa_freq = F.normalize(p_aa_freq,dim=1,p=1)

        #sanity check: the line count must match n records of
        #1 header + p full lines (+ 1 shorter line when r != 0)
        if (p+2) * n != cpt:
            assert (p+1) * n == cpt
            assert r == 0

        return n, p*self.col_size + r, aa_freq, p_aa_freq

    def __len__(self):
        """Number of indexed families."""
        return self.len

    def sample(self, high, low=0, s=1):
        """Draw ``s`` distinct integers uniformly from ``[low, high)``.

        Uses torch's random generator rather than NumPy's global one: torch
        seeds each DataLoader worker differently, whereas NumPy's state is not
        reseeded per worker.

        Returns:
            numpy integer array of length ``s``.

        Raises:
            ValueError: if ``s`` exceeds the size of the range.
        """
        span = high - low
        if s > span:
            raise ValueError("Cannot take a larger sample than population when 'replace=False'")
        return torch.randperm(span)[:s].numpy() + low

    def __getitem__(self, idx, sample_size='auto'):
        """Draw random training samples from family ``idx``.

        Each sample picks two distinct sequences i and j, removes the columns
        where both have a gap, then picks a column k where both hold a standard
        amino acid. The features are the window of ``2*cont_size+1`` residues of
        i around k followed by the ``2*cont_size`` context residues of j (its
        centre removed); the target is the centre residue of j. Positions
        outside the alignment are filled with 'Z'.

        Args:
            idx: family index in ``[0, len(self))``.
            sample_size: number of draws; any non-int value (default 'auto')
                means ``max(1, round(n**2 * l / div))``.

        Returns:
            Tuple ``(X, y, PIDs, local_PIDs, pfreqs, local_pfreqs, pos, lengths)``
            where S is the number of successful draws (possibly 0) and
            W = 4*cont_size+1:
              X            (S, W, 21) float one-hot, the 'Z' channel dropped
              y            (S,) long, target amino-acid id (< 20)
              PIDs         (S,) float, identical non-gap columns / columns where
                           neither sequence has a gap
              local_PIDs   (S,) float, same ratio over the context window only
              pfreqs       (S, 20) family amino-acid frequencies
              local_pfreqs (S, 2, 20) frequencies of sequences i and j
              pos          (S, W) float, column index of every feature position
              lengths      (S,) long, pair alignment length after gap removal
        """
        c = self.cont_size
        feat_len = 4*c + 1

        # Offsets relative to k: full window of i, then left and right context of j.
        precomputed_pos = torch.tensor(
            list(range(-c, c+1)) + list(range(-c, 0)) + list(range(1, c+1))
        ).float()

        data_path = os.path.join(self.data_dir, self.paths[idx])
        n = self.seq_nums[idx]
        l = self.seq_lens[idx]

        if type(sample_size) != int:
            sample_size = max(1, round((n**2 * l)/self.div))

        p, r = divmod(l, self.col_size) # l = p * col_size + r
        sequence_line_count = p+2 if r else p+1 #header included

        X, y = [], []
        PIDs, local_PIDs = [], []
        pfreqs, local_pfreqs = [], []
        lengths, pos = [], []

        for _ in range(sample_size):
            i, j = (int(v) for v in self.sample(n, s=2))

            # linecache lines are 1-indexed; +1 more skips the record header.
            start_i = 2 + sequence_line_count*i
            start_j = 2 + sequence_line_count*j

            cols_i, cols_j = [], []
            matches = 0 #columns holding the same non-gap symbol
            l_ij = 0 #columns where neither sequence has a gap

            for offset in range(sequence_line_count-1):
                line_i = linecache.getline(data_path, start_i + offset)[:-1]
                line_j = linecache.getline(data_path, start_j + offset)[:-1]
                for aa_i, aa_j in zip(line_i, line_j):
                    if aa_i == '-' and aa_j == '-':
                        continue #column carries no information for this pair
                    cols_i.append(aa_i)
                    cols_j.append(aa_j)
                    if aa_i == aa_j:
                        matches += 1
                    if aa_i != '-' and aa_j != '-':
                        l_ij += 1

            seq_i = ''.join(cols_i)
            seq_j = ''.join(cols_j)
            PID_ij = matches/l_ij if l_ij else 0.0

            align_l = len(seq_i)
            possible_k = [k for k, (a_i, a_j) in enumerate(zip(seq_i, seq_j))
                          if ALPHABET.get(a_i,21) < 20 and ALPHABET.get(a_j,21) < 20]
            if not possible_k:
                continue #no column where both sequences hold a standard amino acid

            k = possible_k[int(torch.randint(len(possible_k), (1,)))]

            window = range(k-c, k+c+1)
            window_i = ''.join(seq_i[w] if 0 <= w < align_l else 'Z' for w in window)
            window_j = ''.join(seq_j[w] if 0 <= w < align_l else 'Z' for w in window)

            y_j = ALPHABET.get(window_j[c], 21) # 'Z' is the default value for rare AA
            X_i = [ALPHABET.get(aa, 21) for aa in (window_i + window_j[:c] + window_j[c+1:])]

            assert y_j < 20
            assert X_i[c] < 20

            # Local identity over the context only (centre column excluded).
            context = list(zip(window_i[:c], window_j[:c])) + list(zip(window_i[c+1:], window_j[c+1:]))
            local_matches = sum(1 for AA1, AA2 in context
                                if AA1 == AA2 and ALPHABET.get(AA1,21) < 20)
            loc_comp = sum(1 for AA1, AA2 in context
                           if ALPHABET.get(AA1,21) < 20 and ALPHABET.get(AA2,21) < 20)

            X.append(X_i)
            y.append(y_j)
            PIDs.append(PID_ij)
            local_PIDs.append(local_matches/loc_comp if loc_comp else 0.0)
            pfreqs.append(self.aa_freqs[idx])
            local_pfreqs.append(torch.stack((self.p_aa_freqs[idx][i], self.p_aa_freqs[idx][j])))
            lengths.append(align_l)
            pos.append(k + precomputed_pos)

        linecache.clearcache()

        if not X:
            # No valid draw: return correctly shaped empty tensors so the batch
            # can still be concatenated.
            return (torch.zeros(0, feat_len, 21),
                    torch.zeros(0, dtype=torch.long),
                    torch.zeros(0),
                    torch.zeros(0),
                    torch.zeros(0, 20),
                    torch.zeros(0, 2, 20),
                    torch.zeros(0, feat_len),
                    torch.zeros(0, dtype=torch.long))

        X = F.one_hot(torch.tensor(X),22)[:,:,0:-1].float() #drop the 'Z' channel
        out = (X,
               torch.tensor(y).long(),
               torch.tensor(PIDs, dtype=torch.float32),
               torch.tensor(local_PIDs, dtype=torch.float32),
               torch.stack(pfreqs),
               torch.stack(local_pfreqs),
               torch.stack(pos),
               torch.tensor(lengths))
        return out

    

In [24]:
# One-off construction of the three dataset indexes and their pickled caches.
# The code is kept inside a string literal so that it is NOT executed; the
# cell only evaluates (and displays) the string.
"""
train_dataset = MyDataset(r"data/train_data",cont_size = 6)
test_dataset = MyDataset(r"data/test_data",cont_size = 6,div=700000)
val_dataset = MyDataset(r"data/val_data",cont_size = 6,div=700000)

fname = 'data/train_dataset1.pth'
savepath = Path(fname)
with savepath.open("wb") as fp:
    torch.save(train_dataset,fp)
    
fname = 'data/test_dataset1.pth'
savepath = Path(fname)
with savepath.open("wb") as fp:
    torch.save(test_dataset,fp)
    
fname = 'data/val_dataset1.pth'
savepath = Path(fname)
with savepath.open("wb") as fp:
    torch.save(val_dataset,fp)
"""

'\ntrain_dataset = MyDataset(r"data/train_data",cont_size = 6)\ntest_dataset = MyDataset(r"data/test_data",cont_size = 6,div=700000)\nval_dataset = MyDataset(r"data/val_data",cont_size = 6,div=700000)\n\nfname = \'data/train_dataset1.pth\'\nsavepath = Path(fname)\nwith savepath.open("wb") as fp:\n    torch.save(train_dataset,fp)\n    \nfname = \'data/test_dataset1.pth\'\nsavepath = Path(fname)\nwith savepath.open("wb") as fp:\n    torch.save(test_dataset,fp)\n    \nfname = \'data/val_dataset1.pth\'\nsavepath = Path(fname)\nwith savepath.open("wb") as fp:\n    torch.save(val_dataset,fp)\n'

In [25]:
def _load_or_build_dataset(cache_file, raw_dir, **dataset_kwargs):
    """Return the dataset pickled in ``cache_file``, building it from ``raw_dir`` if absent.

    Scanning the raw FASTA folder is slow, so the resulting ``MyDataset`` is
    cached with ``torch.save``. When the cache is missing the dataset is
    rebuilt and the cache written, which keeps the notebook runnable from a
    fresh checkout of the raw data.
    """
    cache_path = Path(cache_file)
    if cache_path.is_file():
        # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
        with cache_path.open("rb") as fp:
            return torch.load(fp)

    dataset = MyDataset(raw_dir, **dataset_kwargs)
    with cache_path.open("wb") as fp:
        torch.save(dataset, fp)
    return dataset


train_dataset = _load_or_build_dataset('data/train_dataset1.pth', r"data/train_data", cont_size=6)
test_dataset = _load_or_build_dataset('data/test_dataset1.pth', r"data/test_data", cont_size=6, div=700000)
val_dataset = _load_or_build_dataset('data/val_dataset1.pth', r"data/val_data", cont_size=6, div=700000)

print(len(train_dataset))
print(len(test_dataset))
print(len(val_dataset))

13219
2838
2826


In [26]:
def my_collate(batch):
    """Merge dataset items into one batch by concatenating along dim 0.

    Every item is an 8-tuple of tensors whose first dimension is that item's
    own number of samples, so the fields are concatenated rather than stacked.

    Returns:
        Tuple ``(data, target, PID, lPID, pfreqs, lpfreqs, pos, length)``.
    """
    data, target, PID, lPID, pfreqs, lpfreqs, pos, length = (
        torch.cat([item[field] for item in batch], dim=0) for field in range(8)
    )
    return data, target, PID, lPID,pfreqs,lpfreqs, pos, length

In [27]:
# The three splits share the same batching options. A "batch" is 64 families;
# my_collate concatenates the variable number of samples drawn from each.
loader_options = dict(batch_size=64, shuffle=True, collate_fn=my_collate, num_workers=4)

train_dataloader = DataLoader(train_dataset, **loader_options)

test_dataloader = DataLoader(test_dataset, **loader_options)

val_dataloader = DataLoader(val_dataset, **loader_options)

In [28]:
class AttBlock(nn.Module):
    """Multi-head self-attention with explicit Q/K/V projections.

    The input is projected (without bias) to ``num_heads*head_dims`` features
    for queries, keys and values, passed through ``nn.MultiheadAttention``
    (batch-first), then mapped linearly to ``out_features``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output; a falsy value
            (default ``None``) means ``in_features``.
        num_heads: number of attention heads.
        head_dims: features per head.
    """

    def __init__(self,in_features,out_features=None,num_heads=8,head_dims=24):
        super().__init__()
        if not out_features:
            out_features = in_features
        embed_dim = num_heads*head_dims

        self.Q_w = nn.Linear(in_features, embed_dim, bias=False)
        self.K_w = nn.Linear(in_features, embed_dim, bias=False)
        self.V_w = nn.Linear(in_features, embed_dim, bias=False)

        self.att = nn.MultiheadAttention(embed_dim, num_heads=num_heads, batch_first=True)
        self.lin = nn.Linear(embed_dim, out_features)

    def forward(self,x):
        """Apply self-attention to ``x`` of shape ``(batch, seq, in_features)``."""
        attended, _ = self.att(self.Q_w(x), self.K_w(x), self.V_w(x), need_weights=False)
        return self.lin(attended)

In [29]:
class FeedFoward(nn.Module):
    """Two-layer perceptron with a GELU after each linear layer.

    Args:
        in_features: input width.
        out_features: output width; a falsy value (default ``None``) means
            ``in_features``.
        wide_factor: the hidden width is ``wide_factor * in_features``.
    """

    def __init__(self,in_features,out_features=None,wide_factor=4):
        super().__init__()
        if not out_features:
            out_features = in_features
        width = wide_factor * in_features

        self.lin1 = nn.Linear(in_features, width)
        self.act1 = nn.GELU()
        self.lin2 = nn.Linear(width, out_features)
        self.act2 = nn.GELU()

    def forward(self,x):
        """Return ``act2(lin2(act1(lin1(x))))``."""
        hidden = self.act1(self.lin1(x))
        return self.act2(self.lin2(hidden))

In [30]:
def drop_path(x, drop_prob: float = 0.1, training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: input tensor; the first dimension indexes the samples.
        drop_prob: probability of zeroing a whole sample. ``None`` is treated
            like ``0`` (no drop) so that a ``DropPath`` built with its default
            argument does not fail in training mode.
        training: when False the input is returned unchanged.
        scale_by_keep: divide the kept samples by the keep probability so the
            expected value is preserved.

    Returns:
        ``x`` itself when nothing is dropped, otherwise a new tensor of the
        same shape with some samples zeroed.
    """
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (Stochastic Depth per sample).

    The drop is only active while the module is in training mode.

    Args:
        drop_prob: probability of zeroing a whole sample.
        scale_by_keep: rescale the kept samples by the keep probability.
    """
    def __init__(self, drop_prob=None, scale_by_keep=True):
        super().__init__()
        self.scale_by_keep = scale_by_keep
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )


In [31]:
class Block(nn.Module):
    """Post-norm transformer block: attention then feed-forward, each residual.

    Both residual branches go through the same ``DropPath`` module and are
    followed by their own ``LayerNorm``. Input and output widths are both
    ``in_features``.
    """

    def __init__(self,in_features,num_heads=8,head_dims=24,wide_factor=4,drop=0.1):
        super().__init__()

        self.att_block = AttBlock(in_features, num_heads=num_heads, head_dims=head_dims)
        self.ff = FeedFoward(in_features, wide_factor=wide_factor)
        self.drop_path = DropPath(drop)

        self.norm1 = nn.LayerNorm(in_features)
        self.norm2 = nn.LayerNorm(in_features)

    def forward(self,x):
        attended = self.norm1(x + self.drop_path(self.att_block(x)))
        return self.norm2(attended + self.drop_path(self.ff(attended)))

In [32]:
class Classifier_Head(nn.Module):
    """MLP classifier applied to the flattened sequence of features.

    The input is reshaped to ``(-1, seq_len*in_features)`` and goes through one
    ``Linear -> GELU -> Dropout(0.2)`` stage per entry of ``clf_dims``, then a
    final ``Linear`` producing ``out_size`` values.
    """

    def __init__(self,in_features,clf_dims,out_size,seq_len):
        super().__init__()
        self.in_dim = seq_len*in_features

        widths = [self.in_dim, *clf_dims]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.GELU(), nn.Dropout(p=0.2)]
        stack.append(nn.Linear(widths[-1], out_size))

        self.clf = nn.Sequential(*stack)

    def forward(self,x):
        flat = x.reshape((-1, self.in_dim))
        return self.clf(flat)

In [37]:
class AttNet(nn.Module):
    """Attention network predicting the amino acid aligned with a given residue.

    The batch fields are assembled by ``to_input`` into one token sequence of
    width ``in_features``, passed through a stack of ``Block`` modules (one per
    entry of ``num_heads`` / ``head_dims`` / ``wide_factors`` / ``drops``), and
    flattened into ``Classifier_Head``, which outputs ``out_size`` logits.

    Args:
        in_features: token width used by every block.
        num_heads, head_dims, wide_factors, drops: per-block hyper-parameters,
            zipped together (the shortest list sets the number of blocks).
        input_dim: number of one-hot channels of a residue token.
        out_size: number of output logits.
        seq_len: number of tokens given to the classifier head.
        clf_dims: hidden widths of the classifier head.
        cont_size: context size on each side of the centre residue.

    Note:
        The constructor reads ``data/freq.pth`` relative to the working
        directory and fails if that file is missing.
    """
    def __init__(self,in_features,num_heads,head_dims,wide_factors,drops,input_dim=21,out_size=20,seq_len=30,clf_dims=[256,64],cont_size=6):
        super().__init__()
        
        blocks = []
        for n_h, h_d,w,d in zip(num_heads,head_dims,wide_factors,drops):
            blocks.append(Block(in_features,num_heads=n_h,head_dims=h_d,wide_factor=w,drop=d))
        self.feature_extractor = nn.Sequential(*blocks)
        self.in_features = in_features
        self.input_dim = input_dim
        self.clf = Classifier_Head(in_features,clf_dims,out_size=out_size,seq_len=seq_len)
        
        self.cont_size=cont_size
        
        # Trainable table initialised with the log of the tensor stored in
        # data/freq.pth; a row is selected by amino-acid id in to_input.
        # NOTE(review): presumably a 20x20 substitution-frequency table — confirm.
        sp = Path("data/freq.pth")
        with sp.open("rb") as fp:
            self.F = nn.Parameter(torch.log(torch.load(fp)))
            
        # Embeds the scalar PID into one token of width in_features.
        pid_layers = [nn.Linear(1,in_features),nn.Sigmoid()]
        self.pid_l = nn.Sequential(*pid_layers)
    
    def to_input(self,x,PID,pfreqs,lpfreqs,pos,length):
        """Assemble the token sequence given to the attention blocks.

        Token layout along dim 1 (c = cont_size):
          * 2c+1 tokens: window of the first sequence,
          * 2c+1 tokens: context of the second sequence, with a token built
            from ``softmax(self.F[centre residue id])`` inserted at the centre,
          * 1 token: ``pfreqs`` zero-padded,
          * tokens from ``lpfreqs`` zero-padded,
          * 1 token: embedding of ``PID``.

        Residue tokens are extended with the position divided by ``length``
        followed by cos/sin encodings of the position.

        NOTE(review): assumes ``x`` is ``(batch, 4c+1, input_dim)`` one-hot and
        ``pos`` is ``(batch, 4c+1)`` as produced by ``MyDataset`` — confirm for
        other callers.
        """
        # Id of the centre residue of the first sequence.
        X_idx = torch.argmax(x[:,self.cont_size],dim=1)
        seq1 = x[:,:2*self.cont_size+1]
        # Softmax of the table row, padded with one zero channel so it has as
        # many channels as a residue token of x.
        y_freq = F.pad(F.softmax(self.F[X_idx],dim=1).unsqueeze(1), pad=(0, 1), mode='constant', value=0) 
        # Second sequence: left context, inserted centre token, right context.
        seq2 = torch.cat((x[:,2*self.cont_size+1:3*self.cont_size+1],y_freq,x[:,3*self.cont_size+1:]),dim=1)
        # Relative position: column index divided by the alignment length.
        aa_pos = pos[:,:2*self.cont_size+1]/length.unsqueeze(1)
        aa_pos = aa_pos.unsqueeze(2)
        # Number of cos/sin pairs that fit in the channels left after the
        # residue channels and the relative position.
        pos_dim = (self.in_features-self.input_dim-1)//2
        
        for i in range(pos_dim): #positionnal_encoding
            p = torch.cos(pos[:,:2*self.cont_size+1]/(4**(2*i/pos_dim))).unsqueeze(2)
            ip = torch.sin(pos[:,:2*self.cont_size+1]/(4**(2*i/pos_dim))).unsqueeze(2)
            aa_pos = torch.cat([aa_pos,p,ip],dim=2)

        # Both sequences receive the same positional channels (first window's).
        seq1 = torch.cat([seq1,aa_pos],dim=2)
        seq2 = torch.cat([seq2,aa_pos],dim=2)
        X = torch.cat([seq1,seq2],dim=1)
        
        # Right-pad the frequency vectors with zeros on the last dimension.
        pad = (0,self.in_features-self.input_dim+1)
        pf = F.pad(pfreqs.unsqueeze(1),pad =pad, mode='constant', value=0)
        pid = self.pid_l(PID.unsqueeze(1)).unsqueeze(1)
        lpf = F.pad(lpfreqs,pad=pad, mode='constant', value=0)

        X_input = torch.cat([X,pf,lpf,pid],dim=1)
        
        return X_input
    
    def forward(self,x,PID,pfreqs,lpfreqs,pos,length):
        """Return the classifier logits for a batch."""
        X_input = self.to_input(x,PID,pfreqs,lpfreqs,pos,length)
        features = self.feature_extractor(X_input)
        out = self.clf(features)
        return out

In [38]:
class State:
    """Checkpointable training state: model, optimizer, scheduler and epoch counter.

    ``epoch`` starts at 0 and holds the number of completed epochs, so training
    can resume from it.
    """

    def __init__(self,model,optim,scheduler):
        self.model, self.optim, self.scheduler = model, optim, scheduler
        self.epoch = 0

In [63]:
@torch.no_grad()
def _squared_error_score(model, loader, n_passes):
    """Mean over all samples of ``sum((softmax(logits) - one_hot(y))**2)``.

    The loader is iterated ``n_passes`` times; every pass contributes its
    samples to the mean.
    """
    per_sample = []
    for _ in range(n_passes):
        for X, y, PID, lPID, pfreqs, lpfreqs, pos, length in loader:
            X = X.to(device)
            y = y.to(device)
            PID = PID.to(device)
            pos = pos.to(device)
            pfreqs = pfreqs.to(device)
            lpfreqs = lpfreqs.to(device)
            length = length.to(device)

            y_hat = F.softmax(model(X,PID,pfreqs,lpfreqs,pos,length), dim=1)
            target = F.one_hot(y, 20)
            per_sample.append(torch.sum((y_hat-target)**2, dim=1).detach().cpu())

    return torch.mean(torch.cat(per_sample)).item()


def eval_model(fname,train,test,val,N=10):
    """Score a saved checkpoint on the train, test and validation loaders.

    Args:
        fname: path of a checkpoint written with ``torch.save(state)``; the
            loaded object must expose the network as ``.model``.
        train, test, val: data loaders yielding the 8-field batches.
        N: number of passes over each loader; the datasets sample randomly, so
            several passes average over more draws.

    Returns:
        ``(score_train, score_test, score_val)``, each the mean squared
        distance between the softmax output and the one-hot target.
    """
    savepath = Path(fname)
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    with savepath.open("rb") as fp:
        state = torch.load(fp)

    state.model.eval()

    print('EVALUATING ON TRAIN DATA : ')
    score_train = _squared_error_score(state.model, train, N)
    print(f"{score_train = }\n")

    print('EVALUATING ON TEST DATA : ')
    score_test = _squared_error_score(state.model, test, N)
    print(f"{score_test = }\n")

    print('EVALUATING ON VAL DATA : ')
    score_val = _squared_error_score(state.model, val, N)
    print(f"{score_val = }\n")

    return score_train, score_test, score_val

In [64]:
def main(train_loader,val_loader,epochs=101,fname="models/state.pth",fnameb=None,state=None,last_epoch_sched=float('inf'),use_mut=True):
    """Train a model (or resume its training), checkpointing every epoch.

    Args:
        train_loader, val_loader: loaders yielding the 8-field batches.
        epochs: training stops once ``state.epoch`` reaches this value.
        fname: checkpoint path. If the file exists it is loaded and takes
            precedence over ``state``.
        fnameb: path of the best-validation checkpoint; defaults to ``fname``
            with ``_best`` inserted before the extension.
        state: initial ``State`` used when no checkpoint exists; a default
            3-block ``AttNet`` is created when it is ``None``.
        last_epoch_sched: the scheduler is stepped only for epochs below it.
        use_mut: add a binary cross-entropy term on the predicted probability
            that the residue is conserved (target equals the centre residue).

    Returns:
        ``(List_Loss, Eval_Loss, state)``: per-epoch training losses (0-dim
        tensors), per-epoch validation scores (floats), and the final state.

    The best validation score is stored on the state as ``best_score`` so that
    resuming does not overwrite the best checkpoint with a worse model.
    """
    #getting the acceleration device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if fnameb is None:
        # splitext instead of slicing 4 characters: works for any extension.
        root, ext = os.path.splitext(fname)
        fnameb = root + '_best' + ext

    #loading from previous checkpoint
    savepath = Path(fname)
    if savepath.is_file():
        with savepath.open("rb") as fp:
            state = torch.load(fp)
    elif state is None:
        model = AttNet(22,[8,8,8],[24,24,24],[4,4,4],[0.1,0.1,0.1])
        model = model.to(device)
        optim = torch.optim.AdamW(model.parameters(),lr = 0.0001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim,16,2)
        state = State(model,optim,scheduler)

    #to get the best model (restored from the checkpoint when resuming)
    best = getattr(state, "best_score", float('inf'))

    # Index of the centre residue in X; falls back to 6 for models that do not
    # expose cont_size.
    center = getattr(state.model, "cont_size", 6)

    # NOTE(review): fixed normaliser applied to the sum-reduced loss; its origin
    # is not documented here — presumably a typical number of samples per batch.
    loss_scale = 311

    Loss = nn.CrossEntropyLoss(reduction='sum')
    LossMut = nn.BCELoss(reduction='sum')

    #for logs
    List_Loss = []
    Eval_Loss = []
    for epoch in range(state.epoch, epochs):
        batch_losses = []
        state.model.train()
        for X,y, PID, lPID,pfreqs,lpfreqs,pos,length in train_loader:
            X = X.to(device)
            y = y.to(device)
            PID = PID.to(device)
            pos = pos.to(device)
            pfreqs = pfreqs.to(device)
            lpfreqs = lpfreqs.to(device)
            length = length.to(device)

            state.optim.zero_grad()
            y_hat = state.model(X,PID,pfreqs,lpfreqs,pos,length)

            if use_mut:
                X_idx = torch.argmax(X[:,center],dim=1)
                y_true = (X_idx == y).float().unsqueeze(1)  #0 if a mutation happens else 1
                y_pred = F.softmax(y_hat,dim=1)
                y_pred = y_pred.gather(1,X_idx.view(-1,1)) #the Xth component of y_hat should be predicting ^
                l = (Loss(y_hat,y) + LossMut(y_pred,y_true))/loss_scale
            else:
                l = Loss(y_hat,y)/loss_scale
            l.backward()
            state.optim.step()

            batch_losses.append(l.detach().cpu())
        List_Loss.append(torch.mean(torch.stack(batch_losses)).detach().cpu())
        state.epoch = epoch + 1
        if epoch < last_epoch_sched:
            state.scheduler.step()

        with torch.no_grad():
            eval_losses = []
            state.model.eval()
            for X,y, PID, lPID,pfreqs,lpfreqs, pos,length in val_loader:
                X = X.to(device)
                y = y.to(device)
                PID = PID.to(device)
                pos = pos.to(device)
                pfreqs = pfreqs.to(device)
                lpfreqs = lpfreqs.to(device)
                length = length.to(device)

                y_hat = state.model(X,PID,pfreqs,lpfreqs,pos,length)
                y_hat = F.softmax(y_hat,dim=1)

                y = F.one_hot(y, 20)
                eval_l = (y_hat-y)**2
                eval_losses.append(torch.sum(eval_l,dim=1).detach().cpu())

            score = torch.mean(torch.cat(eval_losses)).detach().cpu().item()
            Eval_Loss.append(score)

        improved = score < best
        if improved:
            best = score
            state.best_score = best

        # Saved after validation so the checkpoint carries the up-to-date best score.
        with Path(fname).open("wb") as fp:
            torch.save(state,fp)

        if improved:
            with Path(fnameb).open("wb") as fp:
                torch.save(state,fp)

        print(f"epoch n°{epoch} : train_loss = {List_Loss[-1]}, val_loss = {Eval_Loss[-1]}") 

    return List_Loss,Eval_Loss,state

In [65]:
def get_params(input_size,N,head,head_dim,wide_factor,drop_prob):
    """Expand scalar block hyper-parameters into the per-block lists AttNet takes.

    Returns:
        ``(input_size, heads, head_dims, wide_factors, drop_probs)`` where each
        list repeats the corresponding scalar ``N`` times.
    """
    heads = [head] * N
    head_dims = [head_dim] * N
    wide_factors = [wide_factor] * N
    drop_probs = [drop_prob] * N
    return (input_size, heads, head_dims, wide_factors, drop_probs)

In [66]:
# Build the 4-block configuration (token width 32, 8 heads of 32 features) and
# print its layer summary. The shapes below are one dummy batch for each
# argument of AttNet.forward: x, PID, pfreqs, lpfreqs, pos, length.
params = get_params(32,4,8,32,4,0.1)
model = AttNet(*params, clf_dims=[1024,256,64]).to(device)
summary_input_shapes = [(1,25,21), (1,), (1,20), (1,2,20), (1,25), (1,)]
torchinfo.summary(model, summary_input_shapes)

Layer (type:depth-idx)                        Output Shape              Param #
AttNet                                        [1, 20]                   400
├─Sequential: 1-1                             [1, 32]                   --
│    └─Linear: 2-1                            [1, 32]                   64
│    └─Sigmoid: 2-2                           [1, 32]                   --
├─Sequential: 1-2                             [1, 30, 32]               --
│    └─Block: 2-3                             [1, 30, 32]               --
│    │    └─AttBlock: 3-1                     [1, 30, 32]               295,968
│    │    └─DropPath: 3-2                     [1, 30, 32]               --
│    │    └─LayerNorm: 3-3                    [1, 30, 32]               64
│    │    └─FeedFoward: 3-4                   [1, 30, 32]               8,352
│    │    └─DropPath: 3-5                     [1, 30, 32]               --
│    │    └─LayerNorm: 3-6                    [1, 30, 32]               64
│    └─Bloc

In [None]:
# Train the 4-block model. main() resumes from `fname` when that checkpoint
# already exists, in which case the freshly built `state` is ignored.
params = get_params(32,4,8,32,4,0.1)
model = AttNet(*params, clf_dims=[1024,256,64]).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=0.0003)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=16, T_mult=2)
state = State(model, optim, scheduler)

fname = "models/stateATT_4L_fixed.pth" 
start = time.time()
Train, Eval, _ = main(
    train_dataloader,
    val_dataloader,
    fname=fname,
    epochs=4080,
    state=state,
    use_mut=False,
)
stop = time.time()
print(stop-start)

epoch n°0 : train_loss = 2.7820186614990234, val_loss = 0.8786128163337708
epoch n°1 : train_loss = 2.496732711791992, val_loss = 0.8379143476486206
epoch n°2 : train_loss = 2.396580696105957, val_loss = 0.8233270645141602
epoch n°3 : train_loss = 2.3632419109344482, val_loss = 0.8242983222007751
epoch n°4 : train_loss = 2.341736316680908, val_loss = 0.8151569366455078
epoch n°5 : train_loss = 2.330084800720215, val_loss = 0.8113887906074524
epoch n°6 : train_loss = 2.3223018646240234, val_loss = 0.8084423542022705
epoch n°7 : train_loss = 2.3109476566314697, val_loss = 0.8103057146072388
epoch n°8 : train_loss = 2.3055741786956787, val_loss = 0.8092596530914307
epoch n°9 : train_loss = 2.294175386428833, val_loss = 0.8076428771018982
epoch n°10 : train_loss = 2.294893264770508, val_loss = 0.8058606386184692
epoch n°11 : train_loss = 2.2987349033355713, val_loss = 0.8114710450172424
epoch n°12 : train_loss = 2.287438154220581, val_loss = 0.8090648651123047
epoch n°13 : train_loss = 2.2

epoch n°109 : train_loss = 2.194643020629883, val_loss = 0.7960415482521057
epoch n°110 : train_loss = 2.2085750102996826, val_loss = 0.7987204790115356
epoch n°111 : train_loss = 2.1956968307495117, val_loss = 0.795861005783081
epoch n°112 : train_loss = 2.2203705310821533, val_loss = 0.7983351945877075
epoch n°113 : train_loss = 2.2149548530578613, val_loss = 0.8011788129806519
epoch n°114 : train_loss = 2.2284748554229736, val_loss = 0.8026575446128845
epoch n°115 : train_loss = 2.2205591201782227, val_loss = 0.8040472269058228
epoch n°116 : train_loss = 2.224886178970337, val_loss = 0.799980103969574
epoch n°117 : train_loss = 2.2168514728546143, val_loss = 0.8045084476470947
epoch n°118 : train_loss = 2.215970039367676, val_loss = 0.800125002861023
epoch n°119 : train_loss = 2.2299087047576904, val_loss = 0.7986940741539001
epoch n°120 : train_loss = 2.2166528701782227, val_loss = 0.7978545427322388
epoch n°121 : train_loss = 2.2222299575805664, val_loss = 0.8009495735168457
epoch

In [None]:
# Training loss of the 4-block run, one point per epoch.
fig, ax = plt.subplots()
ax.plot(np.arange(len(Train)), Train)
ax.set(xlabel="epoch", ylabel="training loss", title="Training loss per epoch (4-block model)")
plt.show()

In [None]:
# Validation score of the 4-block run, one point per epoch.
fig, ax = plt.subplots()
ax.plot(np.arange(len(Eval)), Eval)
ax.set(xlabel="epoch", ylabel="validation score", title="Validation score per epoch (4-block model)")
plt.show()

In [None]:
# Score the last checkpoint of the 4-block model on the three splits.
fname = "models/stateATT_4L_fixed.pth" 
SCORES = eval_model(
    fname,
    train=train_dataloader,
    test=test_dataloader,
    val=val_dataloader,
)

In [None]:
# Score the best-validation checkpoint of the 4-block model on the three splits.
fname = "models/stateATT_4L_fixed_best.pth" 
SCORES = eval_model(
    fname,
    train=train_dataloader,
    test=test_dataloader,
    val=val_dataloader,
)

In [None]:
# Train the 5-block model. main() resumes from `fname` when that checkpoint
# already exists, in which case the freshly built `state` is ignored.
params = get_params(32,5,8,32,4,0.1)
model = AttNet(*params, clf_dims=[1024,256,64]).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=0.0003)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=16, T_mult=2)
state = State(model, optim, scheduler)

fname = "models/stateATT_5L_fixed.pth" 
start = time.time()
Train, Eval, _ = main(
    train_dataloader,
    val_dataloader,
    fname=fname,
    epochs=4080,
    state=state,
    use_mut=False,
)
stop = time.time()
print(stop-start)

In [None]:
# Training loss of the 5-block run, one point per epoch.
fig, ax = plt.subplots()
ax.plot(np.arange(len(Train)), Train)
ax.set(xlabel="epoch", ylabel="training loss", title="Training loss per epoch (5-block model)")
plt.show()

In [None]:
# Validation score of the 5-block run, one point per epoch.
fig, ax = plt.subplots()
ax.plot(np.arange(len(Eval)), Eval)
ax.set(xlabel="epoch", ylabel="validation score", title="Validation score per epoch (5-block model)")
plt.show()

In [None]:
# Score the last checkpoint of the 5-block model on the three splits.
fname = "models/stateATT_5L_fixed.pth" 
SCORES = eval_model(
    fname,
    train=train_dataloader,
    test=test_dataloader,
    val=val_dataloader,
)

In [None]:
# Score the best-validation checkpoint of the 5-block model on the three splits.
fname = "models/stateATT_5L_fixed_best.pth" 
SCORES = eval_model(
    fname,
    train=train_dataloader,
    test=test_dataloader,
    val=val_dataloader,
)

In [None]:
# Train the 6-block model. main() resumes from `fname` when that checkpoint
# already exists, in which case the freshly built `state` is ignored.
params = get_params(32,6,8,32,4,0.1)
model = AttNet(*params, clf_dims=[1024,256,64]).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=0.0003)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=16, T_mult=2)
state = State(model, optim, scheduler)

fname = "models/stateATT_6L_fixed.pth" 
start = time.time()
Train, Eval, _ = main(
    train_dataloader,
    val_dataloader,
    fname=fname,
    epochs=4080,
    state=state,
    use_mut=False,
)
stop = time.time()
print(stop-start)

In [None]:
# Training loss of the 6-block run, one point per epoch.
fig, ax = plt.subplots()
ax.plot(np.arange(len(Train)), Train)
ax.set(xlabel="epoch", ylabel="training loss", title="Training loss per epoch (6-block model)")
plt.show()

In [None]:
# Validation score of the 6-block run, one point per epoch.
fig, ax = plt.subplots()
ax.plot(np.arange(len(Eval)), Eval)
ax.set(xlabel="epoch", ylabel="validation score", title="Validation score per epoch (6-block model)")
plt.show()

In [None]:
# Score the last checkpoint of the 6-block model on the three splits.
fname = "models/stateATT_6L_fixed.pth" 
SCORES = eval_model(
    fname,
    train=train_dataloader,
    test=test_dataloader,
    val=val_dataloader,
)

In [None]:
# Score the best-validation checkpoint of the 6-block model on the three splits.
fname = "models/stateATT_6L_fixed_best.pth" 
SCORES = eval_model(
    fname,
    train=train_dataloader,
    test=test_dataloader,
    val=val_dataloader,
)