In [1]:
import polars as pl
import pandas as pd
import numpy as np

from settings import gen_dataset

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

import sys
import os
import json

# Playing around with .json formats

In [2]:
# Directory holding the ECPred40 JSON splits. Set the HIERARCHIC_PROTLM_DATA_DIR
# environment variable to point somewhere else; when it is unset, the original
# location is used. os.path.join(..., '') guarantees a trailing separator, which
# matters because file paths are built by plain string concatenation
# (data_path + 'file.json').
data_path = os.path.join(
    os.environ.get('HIERARCHIC_PROTLM_DATA_DIR', '/home/onyxia/work/HierarchicProtLM/data/'),
    '',
)

In [3]:
# Read the three ECPred40 splits from data_path and rename the raw 'sequence'
# column to 'AA_seq'. Only the training file has its 'index' column dropped.
train, validation, test = (
    pd.read_json(data_path + 'ECPred40_train.json')
    .drop('index', axis=1)
    .rename(columns={"sequence": "AA_seq"}),
    pd.read_json(data_path + 'ECPred40_valid.json')
    .rename(columns={"sequence": "AA_seq"}),
    pd.read_json(data_path + 'ECPred40_test.json')
    .rename(columns={"sequence": "AA_seq"}),
)

In [4]:
train

Unnamed: 0,Protein UniProt Acc.,EC Number,AA_seq
0,Q65GK1,2.5.1.61,MRNIIVGSRRSKLAMTQTKWVIKKLEELNPDFTFEIKEIVTKGDRI...
1,P16616,2.5.1.61,MMRTIKVGSRRSKLAMTQTKWVIQKLKEINPSFAFEIKEIVTKGDR...
2,Q1LU25,2.5.1.61,MLNNILKIATRQSPLAIWQANYVRNQLLSFYPTLLIELVPIVTSGD...
3,Q7VRM4,2.5.1.61,MQAKILRIATRKSPLAICQACYVCNKLKHYHPHIQTELIPIITTGD...
4,Q491Z6,2.5.1.61,MKNKILKIATRKSQLAICQAQYVHNELKHYHPTLSIELMPIVTTGD...
...,...,...,...
258022,Q8R121,0.0.0.0,MRVASSLFLPVLLTEVWLVTSFNLSSHSPEASVHLESQDYENQTWE...
258023,Q3URR7,0.0.0.0,MLAEPVPDALEQEHPGAVKLEEDEVGEEDPRLAESRPRPEVAHQLF...
258024,P54479,0.0.0.0,MNVQEALNLLKENGYKYTNKREDMLQLFADSDRYLTAKNVLSALND...
258025,Q9VA00,0.0.0.0,MSASANLANVYAELMRRCGESYTITYGAPPTYLVSMVGAAEAGKKI...


In [5]:
def split_and_create_columns(row):
    """Expand one EC number into its cumulative hierarchy prefixes.

    Parameters
    ----------
    row : mapping-like
        Anything indexable with the key 'EC Number' whose value is a
        dot-separated string, e.g. '2.5.1.61'.

    Returns
    -------
    pd.Series
        Four strings, positionally indexed 0..3: the prefixes made of the
        first one, two, three and four dot-separated components
        ('2', '2.5', '2.5.1', '2.5.1.61' for the example above). When the
        input has fewer than four components, the longest available prefix
        is repeated for the remaining positions.
    """
    parts = row['EC Number'].split('.')
    # Prefix of depth d = the first d components joined back with dots.
    prefixes = ['.'.join(parts[:depth]) for depth in range(1, 5)]
    return pd.Series(prefixes)

In [6]:
# Apply split_and_create_columns to every row of train and store the four
# resulting prefixes as new columns, one per EC hierarchy level. The columns of
# the returned frame are matched to the listed names by position.
train[['ec_first_cat', 'ec_second_cat', 'ec_third_cat', 'ec_fourth_cat']] = train.apply(
    split_and_create_columns,
    axis='columns',
)

In [7]:
train

Unnamed: 0,Protein UniProt Acc.,EC Number,AA_seq,ec_first_cat,ec_second_cat,ec_third_cat,ec_fourth_cat
0,Q65GK1,2.5.1.61,MRNIIVGSRRSKLAMTQTKWVIKKLEELNPDFTFEIKEIVTKGDRI...,2,2.5,2.5.1,2.5.1.61
1,P16616,2.5.1.61,MMRTIKVGSRRSKLAMTQTKWVIQKLKEINPSFAFEIKEIVTKGDR...,2,2.5,2.5.1,2.5.1.61
2,Q1LU25,2.5.1.61,MLNNILKIATRQSPLAIWQANYVRNQLLSFYPTLLIELVPIVTSGD...,2,2.5,2.5.1,2.5.1.61
3,Q7VRM4,2.5.1.61,MQAKILRIATRKSPLAICQACYVCNKLKHYHPHIQTELIPIITTGD...,2,2.5,2.5.1,2.5.1.61
4,Q491Z6,2.5.1.61,MKNKILKIATRKSQLAICQAQYVHNELKHYHPTLSIELMPIVTTGD...,2,2.5,2.5.1,2.5.1.61
...,...,...,...,...,...,...,...
258022,Q8R121,0.0.0.0,MRVASSLFLPVLLTEVWLVTSFNLSSHSPEASVHLESQDYENQTWE...,0,0.0,0.0.0,0.0.0.0
258023,Q3URR7,0.0.0.0,MLAEPVPDALEQEHPGAVKLEEDEVGEEDPRLAESRPRPEVAHQLF...,0,0.0,0.0.0,0.0.0.0
258024,P54479,0.0.0.0,MNVQEALNLLKENGYKYTNKREDMLQLFADSDRYLTAKNVLSALND...,0,0.0,0.0.0,0.0.0.0
258025,Q9VA00,0.0.0.0,MSASANLANVYAELMRRCGESYTITYGAPPTYLVSMVGAAEAGKKI...,0,0.0,0.0.0,0.0.0.0


In [8]:
# Distinct labels observed in train at each EC hierarchy level, as plain lists.
# The lists come from sets, so their order is arbitrary at this point.
first_cat, second_cat, third_cat, fourth_cat = (
    list(set(train[level_column]))
    for level_column in ('ec_first_cat', 'ec_second_cat', 'ec_third_cat', 'ec_fourth_cat')
)

In [9]:
first_cat, second_cat, third_cat, fourth_cat

(['6', '4', '2', '1', '0', '3', '5'],
 ['6.5',
  '1.3',
  '2.5',
  '2.8',
  '1.17',
  '6.2',
  '1.10',
  '1.4',
  '1.2',
  '4.99',
  '1.9',
  '1.18',
  '1.14',
  '3.5',
  '6.4',
  '1.15',
  '1.13',
  '4.3',
  '5.4',
  '3.3',
  '0.0',
  '2.7',
  '2.9',
  '3.11',
  '5.3',
  '2.1',
  '1.7',
  '4.1',
  '3.4',
  '5.99',
  '1.11',
  '1.97',
  '1.8',
  '6.1',
  '3.7',
  '3.1',
  '2.3',
  '6.3',
  '3.2',
  '2.2',
  '2.6',
  '5.2',
  '1.5',
  '5.1',
  '4.4',
  '1.6',
  '2.4',
  '1.1',
  '4.6',
  '3.6',
  '1.16',
  '4.2'],
 ['2.1.1',
  '2.4.99',
  '5.4.99',
  '1.3.5',
  '2.8.4',
  '6.3.4',
  '3.4.19',
  '2.8.1',
  '2.7.14',
  '3.5.1',
  '5.3.3',
  '1.18.6',
  '3.7.1',
  '1.4.4',
  '2.7.2',
  '1.6.5',
  '0.0.0',
  '4.2.99',
  '3.6.3',
  '2.7.4',
  '5.3.4',
  '6.3.3',
  '2.3.3',
  '2.4.2',
  '1.9.3',
  '5.99.1',
  '6.4.1',
  '1.5.1',
  '3.1.22',
  '4.1.3',
  '3.5.2',
  '2.7.9',
  '1.7.99',
  '2.3.1',
  '3.4.21',
  '3.1.21',
  '4.2.2',
  '5.1.3',
  '3.4.13',
  '3.4.14',
  '4.3.1',
  '4.4.1',
  '3.1

In [10]:
# Pool the labels of all four EC hierarchy levels into one flat list.
all_cat = [*first_cat, *second_cat, *third_cat, *fourth_cat]

In [15]:
# Put the pooled labels in lexicographic (string) order, then show how many
# labels there are across all hierarchy levels.
all_cat = sorted(all_cat)
len(all_cat)

828

828 classes différentes

In [16]:
# Map every EC label (all four hierarchy levels pooled in all_cat) to an integer index.
# The labels are sorted and de-duplicated right here instead of relying on an
# earlier cell having sorted all_cat in place: all_cat is assembled from sets,
# whose iteration order for strings can differ between kernel sessions, so an
# unsorted all_cat would give a different label -> index mapping on each run.
# De-duplicating also guarantees the indices run from 0 to len(label2idx) - 1
# without gaps.
label2idx = {ec: i for i, ec in enumerate(sorted(set(all_cat)))}

In [17]:
label2idx

{'0': 0,
 '0.0': 1,
 '0.0.0': 2,
 '0.0.0.0': 3,
 '1': 4,
 '1.1': 5,
 '1.1.1': 6,
 '1.1.1.1': 7,
 '1.1.1.103': 8,
 '1.1.1.17': 9,
 '1.1.1.18': 10,
 '1.1.1.205': 11,
 '1.1.1.23': 12,
 '1.1.1.25': 13,
 '1.1.1.261': 14,
 '1.1.1.262': 15,
 '1.1.1.267': 16,
 '1.1.1.27': 17,
 '1.1.1.290': 18,
 '1.1.1.34': 19,
 '1.1.1.37': 20,
 '1.1.1.38': 21,
 '1.1.1.42': 22,
 '1.1.1.44': 23,
 '1.1.1.49': 24,
 '1.1.1.8': 25,
 '1.1.1.85': 26,
 '1.1.1.86': 27,
 '1.1.1.94': 28,
 '1.1.5': 29,
 '1.1.5.3': 30,
 '1.1.5.4': 31,
 '1.10': 32,
 '1.10.2': 33,
 '1.10.2.2': 34,
 '1.10.3': 35,
 '1.10.3.2': 36,
 '1.10.3.9': 37,
 '1.10.9': 38,
 '1.10.9.1': 39,
 '1.11': 40,
 '1.11.1': 41,
 '1.11.1.15': 42,
 '1.11.1.21': 43,
 '1.11.1.6': 44,
 '1.11.1.7': 45,
 '1.11.1.9': 46,
 '1.13': 47,
 '1.13.11': 48,
 '1.13.11.11': 49,
 '1.13.11.5': 50,
 '1.13.11.54': 51,
 '1.13.11.6': 52,
 '1.14': 53,
 '1.14.13': 54,
 '1.14.13.9': 55,
 '1.14.14': 56,
 '1.14.14.1': 57,
 '1.14.14.18': 58,
 '1.14.14.5': 59,
 '1.14.99': 60,
 '1.14.99.46': 61,
 

To do : 
- [ ] concaténer les colonnes de label en une liste puis mapper les éléments de la liste à leur indice de label donné par le dictionnaire label2idx
- [ ] dataloader et optimizer donnés par huggingface, même si le modèle est maison ?
- [ ] faire un petit training loop pour tester si ça marche en se connectant sur le cluster chilien
