In [32]:
import polars as pl
import pandas as pd
import numpy as np

from settings import gen_dataset

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

import sys
import os
import json

import sys
sys.path.append('../')


# Playing around .json formats

In [33]:
data_path = '/home/onyxia/work/HierarchicProtLM/data/'

In [34]:
train, validation, test = pd.read_json(data_path + 'ECPred40_train.json').drop('index',axis=1).rename(columns={"sequence": "AA_seq"}), pd.read_json(data_path + 'ECPred40_valid.json').rename(columns={"sequence": "AA_seq"}), pd.read_json(data_path + 'ECPred40_test.json').rename(columns={"sequence": "AA_seq"})

In [35]:
train

Unnamed: 0,Protein UniProt Acc.,EC Number,AA_seq
0,Q65GK1,2.5.1.61,MRNIIVGSRRSKLAMTQTKWVIKKLEELNPDFTFEIKEIVTKGDRI...
1,P16616,2.5.1.61,MMRTIKVGSRRSKLAMTQTKWVIQKLKEINPSFAFEIKEIVTKGDR...
2,Q1LU25,2.5.1.61,MLNNILKIATRQSPLAIWQANYVRNQLLSFYPTLLIELVPIVTSGD...
3,Q7VRM4,2.5.1.61,MQAKILRIATRKSPLAICQACYVCNKLKHYHPHIQTELIPIITTGD...
4,Q491Z6,2.5.1.61,MKNKILKIATRKSQLAICQAQYVHNELKHYHPTLSIELMPIVTTGD...
...,...,...,...
258022,Q8R121,0.0.0.0,MRVASSLFLPVLLTEVWLVTSFNLSSHSPEASVHLESQDYENQTWE...
258023,Q3URR7,0.0.0.0,MLAEPVPDALEQEHPGAVKLEEDEVGEEDPRLAESRPRPEVAHQLF...
258024,P54479,0.0.0.0,MNVQEALNLLKENGYKYTNKREDMLQLFADSDRYLTAKNVLSALND...
258025,Q9VA00,0.0.0.0,MSASANLANVYAELMRRCGESYTITYGAPPTYLVSMVGAAEAGKKI...


In [36]:
def split_and_create_columns(row):
    numbers = row['EC Number'].split('.')
    return pd.Series([numbers[0], '.'.join(numbers[:2]), '.'.join(numbers[:3]), '.'.join(numbers[:4])])

In [37]:
train[['ec_first_cat', 'ec_second_cat', 'ec_third_cat', 'ec_fourth_cat']] = train.apply(split_and_create_columns, axis=1)

In [38]:
# Specify the columns you want to concatenate
columns_to_concat = ['ec_first_cat', 'ec_second_cat', 'ec_third_cat', 'ec_fourth_cat']

# Create a new column with grouped numbers as lists
train['labels'] = train.apply(lambda x: list([x[col] for col in columns_to_concat]),axis=1)      


In [39]:
train

Unnamed: 0,Protein UniProt Acc.,EC Number,AA_seq,ec_first_cat,ec_second_cat,ec_third_cat,ec_fourth_cat,labels
0,Q65GK1,2.5.1.61,MRNIIVGSRRSKLAMTQTKWVIKKLEELNPDFTFEIKEIVTKGDRI...,2,2.5,2.5.1,2.5.1.61,"[2, 2.5, 2.5.1, 2.5.1.61]"
1,P16616,2.5.1.61,MMRTIKVGSRRSKLAMTQTKWVIQKLKEINPSFAFEIKEIVTKGDR...,2,2.5,2.5.1,2.5.1.61,"[2, 2.5, 2.5.1, 2.5.1.61]"
2,Q1LU25,2.5.1.61,MLNNILKIATRQSPLAIWQANYVRNQLLSFYPTLLIELVPIVTSGD...,2,2.5,2.5.1,2.5.1.61,"[2, 2.5, 2.5.1, 2.5.1.61]"
3,Q7VRM4,2.5.1.61,MQAKILRIATRKSPLAICQACYVCNKLKHYHPHIQTELIPIITTGD...,2,2.5,2.5.1,2.5.1.61,"[2, 2.5, 2.5.1, 2.5.1.61]"
4,Q491Z6,2.5.1.61,MKNKILKIATRKSQLAICQAQYVHNELKHYHPTLSIELMPIVTTGD...,2,2.5,2.5.1,2.5.1.61,"[2, 2.5, 2.5.1, 2.5.1.61]"
...,...,...,...,...,...,...,...,...
258022,Q8R121,0.0.0.0,MRVASSLFLPVLLTEVWLVTSFNLSSHSPEASVHLESQDYENQTWE...,0,0.0,0.0.0,0.0.0.0,"[0, 0.0, 0.0.0, 0.0.0.0]"
258023,Q3URR7,0.0.0.0,MLAEPVPDALEQEHPGAVKLEEDEVGEEDPRLAESRPRPEVAHQLF...,0,0.0,0.0.0,0.0.0.0,"[0, 0.0, 0.0.0, 0.0.0.0]"
258024,P54479,0.0.0.0,MNVQEALNLLKENGYKYTNKREDMLQLFADSDRYLTAKNVLSALND...,0,0.0,0.0.0,0.0.0.0,"[0, 0.0, 0.0.0, 0.0.0.0]"
258025,Q9VA00,0.0.0.0,MSASANLANVYAELMRRCGESYTITYGAPPTYLVSMVGAAEAGKKI...,0,0.0,0.0.0,0.0.0.0,"[0, 0.0, 0.0.0, 0.0.0.0]"


In [40]:
first_cat = list(set(train['ec_first_cat']))
second_cat = list(set(train['ec_second_cat']))
third_cat = list(set(train['ec_third_cat']))
fourth_cat = list(set(train['ec_fourth_cat']))

all_cat = first_cat + second_cat + third_cat + fourth_cat
all_cat.sort()
len(all_cat) #nb of labels
label2idx = {ec:i for i,ec in enumerate(all_cat)}

unsetted_labels = list(train['ec_first_cat']) + list(train['ec_second_cat']) + list(train['ec_third_cat']) + list(train['ec_fourth_cat'])

In [41]:
class_weights = compute_class_weight(class_weight='balanced', classes=np.sort(list(set(unsetted_labels))) , y=unsetted_labels)
len(class_weights)

#print(weights, len(weights))
# Write the data to text 
#np.savetxt('weights.txt', weights)
#print("Weights saved")

828

In [42]:
class_weights

array([4.05935267e-02, 4.05935267e-02, 4.05935267e-02, 4.05935267e-02,
       4.32839396e-02, 1.95481755e-01, 2.07770684e-01, 5.11276590e+00,
       5.00221961e+00, 5.97039115e+00, 1.68256478e+01, 1.07605887e+01,
       4.89635253e+00, 2.13720699e+00, 1.38120989e+01, 9.84479392e+00,
       2.11764446e+00, 3.25848813e+00, 6.08822782e+00, 2.43529113e+01,
       1.85082126e+00, 7.71175523e+01, 1.36089798e+01, 1.71372339e+01,
       1.36089798e+01, 1.88859312e+01, 3.22442728e+00, 1.74605779e+00,
       1.74605779e+00, 3.30503796e+00, 8.56861693e+00, 5.38029435e+00,
       2.62900747e+00, 1.71372339e+01, 1.71372339e+01, 3.76183182e+00,
       1.14248226e+01, 5.60854926e+00, 1.77963582e+01, 1.77963582e+01,
       1.02368432e+00, 1.02368432e+00, 3.79266651e+00, 2.96605971e+00,
       8.33703268e+00, 5.54138101e+00, 1.32201518e+01, 3.51867159e+00,
       3.51867159e+00, 9.95065191e+00, 1.10167932e+01, 2.15211774e+01,
       2.15211774e+01, 2.05190827e+00, 1.92793881e+01, 1.92793881e+01,
      

In [43]:
np.sort(list(set(unsetted_labels)))

array(['0', '0.0', '0.0.0', '0.0.0.0', '1', '1.1', '1.1.1', '1.1.1.1',
       '1.1.1.103', '1.1.1.17', '1.1.1.18', '1.1.1.205', '1.1.1.23',
       '1.1.1.25', '1.1.1.261', '1.1.1.262', '1.1.1.267', '1.1.1.27',
       '1.1.1.290', '1.1.1.34', '1.1.1.37', '1.1.1.38', '1.1.1.42',
       '1.1.1.44', '1.1.1.49', '1.1.1.8', '1.1.1.85', '1.1.1.86',
       '1.1.1.94', '1.1.5', '1.1.5.3', '1.1.5.4', '1.10', '1.10.2',
       '1.10.2.2', '1.10.3', '1.10.3.2', '1.10.3.9', '1.10.9', '1.10.9.1',
       '1.11', '1.11.1', '1.11.1.15', '1.11.1.21', '1.11.1.6', '1.11.1.7',
       '1.11.1.9', '1.13', '1.13.11', '1.13.11.11', '1.13.11.5',
       '1.13.11.54', '1.13.11.6', '1.14', '1.14.13', '1.14.13.9',
       '1.14.14', '1.14.14.1', '1.14.14.18', '1.14.14.5', '1.14.99',
       '1.14.99.46', '1.15', '1.15.1', '1.15.1.1', '1.16', '1.16.3',
       '1.16.3.1', '1.17', '1.17.1', '1.17.1.8', '1.17.4', '1.17.4.1',
       '1.17.7', '1.17.7.3', '1.17.7.4', '1.18', '1.18.1', '1.18.1.2',
       '1.18.6', '1.18.6.1'

In [44]:
first_cat, second_cat, third_cat, fourth_cat

(['3', '5', '2', '1', '4', '0', '6'],
 ['1.9',
  '3.2',
  '4.99',
  '3.4',
  '3.7',
  '6.4',
  '3.3',
  '5.2',
  '1.8',
  '1.11',
  '4.6',
  '1.4',
  '6.5',
  '1.16',
  '1.17',
  '1.13',
  '1.6',
  '4.2',
  '5.4',
  '5.3',
  '6.2',
  '6.3',
  '1.2',
  '3.5',
  '1.97',
  '5.1',
  '2.8',
  '2.9',
  '2.1',
  '6.1',
  '2.5',
  '5.99',
  '1.5',
  '1.7',
  '2.4',
  '1.14',
  '1.1',
  '1.15',
  '2.7',
  '2.2',
  '3.11',
  '1.18',
  '1.3',
  '4.1',
  '3.1',
  '4.3',
  '2.6',
  '2.3',
  '3.6',
  '4.4',
  '0.0',
  '1.10'],
 ['1.17.7',
  '4.2.2',
  '5.3.1',
  '2.2.1',
  '4.1.1',
  '2.9.1',
  '1.4.3',
  '2.8.1',
  '2.8.4',
  '1.97.1',
  '6.5.1',
  '2.7.10',
  '3.6.3',
  '3.1.3',
  '5.99.1',
  '3.5.4',
  '2.7.12',
  '1.15.1',
  '5.2.1',
  '1.4.4',
  '3.1.13',
  '1.10.2',
  '3.2.1',
  '3.5.2',
  '1.7.1',
  '3.4.19',
  '1.6.5',
  '1.3.7',
  '2.7.11',
  '2.4.2',
  '1.18.6',
  '4.1.99',
  '2.7.14',
  '5.4.99',
  '1.7.99',
  '3.2.2',
  '6.3.1',
  '4.2.1',
  '4.6.1',
  '1.14.14',
  '3.1.11',
  '4.3.1',
 

828 différentes classes 

In [45]:
label2idx

{'0': 0,
 '0.0': 1,
 '0.0.0': 2,
 '0.0.0.0': 3,
 '1': 4,
 '1.1': 5,
 '1.1.1': 6,
 '1.1.1.1': 7,
 '1.1.1.103': 8,
 '1.1.1.17': 9,
 '1.1.1.18': 10,
 '1.1.1.205': 11,
 '1.1.1.23': 12,
 '1.1.1.25': 13,
 '1.1.1.261': 14,
 '1.1.1.262': 15,
 '1.1.1.267': 16,
 '1.1.1.27': 17,
 '1.1.1.290': 18,
 '1.1.1.34': 19,
 '1.1.1.37': 20,
 '1.1.1.38': 21,
 '1.1.1.42': 22,
 '1.1.1.44': 23,
 '1.1.1.49': 24,
 '1.1.1.8': 25,
 '1.1.1.85': 26,
 '1.1.1.86': 27,
 '1.1.1.94': 28,
 '1.1.5': 29,
 '1.1.5.3': 30,
 '1.1.5.4': 31,
 '1.10': 32,
 '1.10.2': 33,
 '1.10.2.2': 34,
 '1.10.3': 35,
 '1.10.3.2': 36,
 '1.10.3.9': 37,
 '1.10.9': 38,
 '1.10.9.1': 39,
 '1.11': 40,
 '1.11.1': 41,
 '1.11.1.15': 42,
 '1.11.1.21': 43,
 '1.11.1.6': 44,
 '1.11.1.7': 45,
 '1.11.1.9': 46,
 '1.13': 47,
 '1.13.11': 48,
 '1.13.11.11': 49,
 '1.13.11.5': 50,
 '1.13.11.54': 51,
 '1.13.11.6': 52,
 '1.14': 53,
 '1.14.13': 54,
 '1.14.13.9': 55,
 '1.14.14': 56,
 '1.14.14.1': 57,
 '1.14.14.18': 58,
 '1.14.14.5': 59,
 '1.14.99': 60,
 '1.14.99.46': 61,
 

In [46]:
def translate_list(lst):
    return [label2idx[item] for item in lst]

train['labels'] = train['labels'].apply(translate_list)
train

Unnamed: 0,Protein UniProt Acc.,EC Number,AA_seq,ec_first_cat,ec_second_cat,ec_third_cat,ec_fourth_cat,labels
0,Q65GK1,2.5.1.61,MRNIIVGSRRSKLAMTQTKWVIKKLEELNPDFTFEIKEIVTKGDRI...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
1,P16616,2.5.1.61,MMRTIKVGSRRSKLAMTQTKWVIQKLKEINPSFAFEIKEIVTKGDR...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
2,Q1LU25,2.5.1.61,MLNNILKIATRQSPLAIWQANYVRNQLLSFYPTLLIELVPIVTSGD...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
3,Q7VRM4,2.5.1.61,MQAKILRIATRKSPLAICQACYVCNKLKHYHPHIQTELIPIITTGD...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
4,Q491Z6,2.5.1.61,MKNKILKIATRKSQLAICQAQYVHNELKHYHPTLSIELMPIVTTGD...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
...,...,...,...,...,...,...,...,...
258022,Q8R121,0.0.0.0,MRVASSLFLPVLLTEVWLVTSFNLSSHSPEASVHLESQDYENQTWE...,0,0.0,0.0.0,0.0.0.0,"[0, 1, 2, 3]"
258023,Q3URR7,0.0.0.0,MLAEPVPDALEQEHPGAVKLEEDEVGEEDPRLAESRPRPEVAHQLF...,0,0.0,0.0.0,0.0.0.0,"[0, 1, 2, 3]"
258024,P54479,0.0.0.0,MNVQEALNLLKENGYKYTNKREDMLQLFADSDRYLTAKNVLSALND...,0,0.0,0.0.0,0.0.0.0,"[0, 1, 2, 3]"
258025,Q9VA00,0.0.0.0,MSASANLANVYAELMRRCGESYTITYGAPPTYLVSMVGAAEAGKKI...,0,0.0,0.0.0,0.0.0.0,"[0, 1, 2, 3]"


In [47]:
train

Unnamed: 0,Protein UniProt Acc.,EC Number,AA_seq,ec_first_cat,ec_second_cat,ec_third_cat,ec_fourth_cat,labels
0,Q65GK1,2.5.1.61,MRNIIVGSRRSKLAMTQTKWVIKKLEELNPDFTFEIKEIVTKGDRI...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
1,P16616,2.5.1.61,MMRTIKVGSRRSKLAMTQTKWVIQKLKEINPSFAFEIKEIVTKGDR...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
2,Q1LU25,2.5.1.61,MLNNILKIATRQSPLAIWQANYVRNQLLSFYPTLLIELVPIVTSGD...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
3,Q7VRM4,2.5.1.61,MQAKILRIATRKSPLAICQACYVCNKLKHYHPHIQTELIPIITTGD...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
4,Q491Z6,2.5.1.61,MKNKILKIATRKSQLAICQAQYVHNELKHYHPTLSIELMPIVTTGD...,2,2.5,2.5.1,2.5.1.61,"[156, 282, 283, 294]"
...,...,...,...,...,...,...,...,...
258022,Q8R121,0.0.0.0,MRVASSLFLPVLLTEVWLVTSFNLSSHSPEASVHLESQDYENQTWE...,0,0.0,0.0.0,0.0.0.0,"[0, 1, 2, 3]"
258023,Q3URR7,0.0.0.0,MLAEPVPDALEQEHPGAVKLEEDEVGEEDPRLAESRPRPEVAHQLF...,0,0.0,0.0.0,0.0.0.0,"[0, 1, 2, 3]"
258024,P54479,0.0.0.0,MNVQEALNLLKENGYKYTNKREDMLQLFADSDRYLTAKNVLSALND...,0,0.0,0.0.0,0.0.0.0,"[0, 1, 2, 3]"
258025,Q9VA00,0.0.0.0,MSASANLANVYAELMRRCGESYTITYGAPPTYLVSMVGAAEAGKKI...,0,0.0,0.0.0,0.0.0.0,"[0, 1, 2, 3]"


To do : 
- [X] concaténer les colonnes de label en une liste puis mapper les éléments de la liste à leur indice de label donné par le dictionnaire label2idx
- [ ] dataloader et optimizer donnés par huggingface, même si le modèle est maison?

Modèle maison mais possible de récupérer les poids pour l'architecture qui nous intéresse uniquement. Possible aussi de reprendre le tokenizer (évidemment lié aux embeddings appris par le modèle). 

NB: Les poids juste encoders et de l'architecture pour la génération conditionnelle sont les mêmes. (explication?)

- [ ] faire un petit training loop pour tester si ça marche en se connectant sur le cluster chilien


### Loading model and tokenizer to use weights and tokenizer from the paper.

In [48]:
!pip install transformers

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [62]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.utils.data import Dataset, DataLoader

tokenizer = AutoTokenizer.from_pretrained("ElnaggarLab/ankh-base")
#model = AutoModelForSeq2SeqLM.from_pretrained("ElnaggarLab/ankh-base")
#model2 = AutoModelForSeq2SeqLM.from_pretrained("ElnaggarLab/ankh-base-encoder", from_flax=True)

In [50]:
help(tokenizer)

Help on T5TokenizerFast in module transformers.models.t5.tokenization_t5_fast object:

class T5TokenizerFast(transformers.tokenization_utils_fast.PreTrainedTokenizerFast)
 |  T5TokenizerFast(vocab_file=None, tokenizer_file=None, eos_token='</s>', unk_token='<unk>', pad_token='<pad>', extra_ids=100, additional_special_tokens=None, add_prefix_space=None, **kwargs)
 |
 |  Construct a "fast" T5 tokenizer (backed by HuggingFace's *tokenizers* library). Based on
 |  [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).
 |
 |  This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
 |  refer to this superclass for more information regarding those methods.
 |
 |  Args:
 |      vocab_file (`str`):
 |          [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
 |          contains the vocabulary necessary to instantiate a tokenizer.
 |     

In [52]:
model_state_dict = model.state_dict()
#model_state_dict2 = model2.state_dict()

In [53]:
model_state_dict

OrderedDict([('transformer_encoder.0.attention.self.query.weight',
              tensor([[-0.0138,  0.0185, -0.0236,  ...,  0.0123,  0.0152, -0.0358],
                      [-0.0247,  0.0075,  0.0147,  ...,  0.0188, -0.0329, -0.0300],
                      [ 0.0041, -0.0353, -0.0160,  ...,  0.0018,  0.0155,  0.0344],
                      ...,
                      [ 0.0124, -0.0111, -0.0268,  ..., -0.0359, -0.0320, -0.0027],
                      [ 0.0272, -0.0241,  0.0051,  ...,  0.0311,  0.0273,  0.0354],
                      [ 0.0257, -0.0159, -0.0212,  ...,  0.0124,  0.0107, -0.0104]])),
             ('transformer_encoder.0.attention.self.query.bias',
              tensor([ 5.4644e-03,  3.8362e-03, -1.0747e-02,  1.4817e-02,  1.9768e-02,
                      -2.9575e-02, -1.0857e-02,  2.3732e-02,  3.8784e-03,  3.4143e-02,
                      -1.2231e-02,  2.9221e-03, -2.1648e-03,  1.5713e-02, -2.7019e-03,
                       3.5505e-02,  1.7571e-02,  3.4366e-04, -2.5734e-03,

In [54]:
#model_state_dict2
#model2.parameters
dir(model)
model.load_state_dict
pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_params

## Instantiating homemade model

In [None]:
from ankh import ConvBertForMultiLabelClassification
from transformers import Trainer, TrainingArguments
from datasets import Dataset
import torch

model = ConvBertForMultiLabelClassification(num_tokens=2, 
                                            input_dim=768, 
                                            nhead=4, 
                                            hidden_dim=384, 
                                            num_hidden_layers=1, 
                                            num_layers=1, 
                                            kernel_size=7, 
                                            dropout=0.2)

model


# Dataset creation
def create_dataset(tokenizer,seqs,labels, max_length):
    tokenized = tokenizer(seqs, max_length=max_length, padding=True, truncation=True) #changed padding to True by default
    dataset = dataset.add_column("labels", labels)
    dataset = Dataset.from_dict(tokenized)


    return dataset

train_set=create_dataset(tokenizer,list(train['AA_seq']),list(train['labels']),max_length=1050)

train_dataloader = DataLoader(train_set, batch_size=16, shuffle=False)

In [None]:
train_set['labels]

In [69]:
model.train()

ConvBertForMultiLabelClassification(
  (transformer_encoder): ModuleList(
    (0): ConvBertLayer(
      (attention): ConvBertAttention(
        (self): ConvBertSelfAttention(
          (query): Linear(in_features=768, out_features=384, bias=True)
          (key): Linear(in_features=768, out_features=384, bias=True)
          (value): Linear(in_features=768, out_features=384, bias=True)
          (key_conv_attn_layer): SeparableConv1D(
            (depthwise): Conv1d(768, 768, kernel_size=(7,), stride=(1,), padding=(3,), groups=768, bias=False)
            (pointwise): Conv1d(768, 384, kernel_size=(1,), stride=(1,), bias=False)
          )
          (conv_kernel_layer): Linear(in_features=384, out_features=14, bias=True)
          (conv_out_layer): Linear(in_features=768, out_features=384, bias=True)
          (unfold): Unfold(kernel_size=[7, 1], dilation=1, padding=[3, 0], stride=1)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): ConvBertSelfOutput(
     

In [90]:
help(DataLoader)

Help on class DataLoader in module torch.utils.data.dataloader:

class DataLoader(typing.Generic)
 |  DataLoader(dataset: torch.utils.data.dataset.Dataset[+T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, sampler: Union[torch.utils.data.sampler.Sampler, Iterable, NoneType] = None, batch_sampler: Union[torch.utils.data.sampler.Sampler[List], Iterable[List], NoneType] = None, num_workers: int = 0, collate_fn: Optional[Callable[[List[~T]], Any]] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable[[int], NoneType]] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: str = '')
 |
 |  Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.
 |
 |  The :class:`~torch.utils.data.DataLoader` supports both map-style and
 |  iterable-style datasets with single- or multi-process loadi

In [87]:
train_dataloader.__getattribute__(0)

TypeError: attribute name must be string, not 'int'

In [76]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Train loop
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    for inputs, labels, _ in train_dataloader:  # You need to replace train_dataloader with your actual data loader
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

TypeError: linear(): argument 'input' (position 1) must be Tensor, not str

In [93]:
for inputs, labels, _ in train_dataloader: 
    print(inputs, labels) # You need to replace train_dataloader with your actual data loader
    break


input_ids attention_mask


In [25]:
training_args = TrainingArguments(
    output_dir="./output",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    logging_dir="./logs",
)

trainer = Trainer(
    model=model,
    args=training_args,
    # Add your dataset and data collator here
    # For example:
     train_dataset=train_set,
    # data_collator=data_collator,
)

# Start training
trainer.train()

TypeError: ConvBertForMultiLabelClassification.forward() missing 1 required positional argument: 'embed'

In [30]:
train_set[0]['input_ids']

[19,
 8,
 17,
 12,
 12,
 6,
 5,
 7,
 8,
 8,
 7,
 14,
 4,
 3,
 19,
 11,
 16,
 11,
 14,
 21,
 6,
 12,
 14,
 14,
 4,
 9,
 9,
 4,
 17,
 13,
 10,
 15,
 11,
 15,
 9,
 12,
 14,
 9,
 12,
 6,
 11,
 14,
 5,
 10,
 8,
 12,
 4,
 10,
 6,
 11,
 4,
 7,
 14,
 6,
 5,
 5,
 14,
 5,
 4,
 15,
 6,
 14,
 9,
 12,
 9,
 16,
 3,
 19,
 4,
 7,
 5,
 9,
 12,
 10,
 19,
 3,
 6,
 20,
 7,
 19,
 14,
 10,
 19,
 13,
 7,
 6,
 4,
 13,
 9,
 5,
 4,
 19,
 6,
 5,
 22,
 12,
 13,
 14,
 8,
 9,
 10,
 3,
 8,
 10,
 6,
 4,
 12,
 7,
 14,
 5,
 8,
 16,
 14,
 4,
 3,
 10,
 4,
 14,
 16,
 5,
 3,
 6,
 6,
 5,
 11,
 7,
 7,
 4,
 8,
 8,
 7,
 3,
 16,
 4,
 4,
 16,
 19,
 8,
 13,
 10,
 4,
 9,
 12,
 14,
 21,
 12,
 8,
 5,
 17,
 12,
 10,
 11,
 8,
 4,
 14,
 14,
 4,
 9,
 11,
 9,
 10,
 18,
 10,
 3,
 12,
 12,
 4,
 3,
 3,
 3,
 5,
 4,
 7,
 8,
 19,
 5,
 21,
 14,
 10,
 10,
 6,
 6,
 11,
 9,
 15,
 4,
 10,
 13,
 9,
 7,
 22,
 4,
 13,
 3,
 6,
 5,
 16,
 5,
 3,
 4,
 3,
 12,
 9,
 22,
 8,
 5,
 7,
 10,
 9,
 9,
 4,
 4,
 7,
 4,
 15,
 7,
 8,
 4,
 17,
 10,
 8,
 18,
 11,
 16,
 

In [31]:
model(train_set[0]['input_ids'])

TypeError: linear(): argument 'input' (position 1) must be Tensor, not list

In [None]:
train