## Training

In [1]:
from pytorch_lightning import Trainer
from datamodule import PAFDatamodule
from esm1b import ProteinClassifier
from pytorch_lightning.callbacks import ModelCheckpoint

# Initialize the data module and model
n_classes = 25
datamodule = PAFDatamodule("../datafiles", batch_size=32)
model = ProteinClassifier(n_classes=n_classes)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',  
    dirpath='../datafiles', 
    filename='best-model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,  
    mode='min', 
)

trainer = Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback]
)

trainer.fit(model=model, datamodule=datamodule)


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'ESMTokenizer'. 
The class this function is called from is 'BertTokenizer'.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You are using a model of type esm to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at facebook/esm-1b were not used when initializing BertModel: ['esm.encoder.layer.13.attention.LayerNorm.bias', 'esm.encoder.layer.2.attention.LayerNorm.bias', 'esm.encoder.layer.21.attention.self.value.weight', 'esm.encoder.layer.12.attention.LayerNorm.weight', 'esm.encoder.layer.0.attention.self.key.bias', 'esm.encoder.layer.12.output.dense.bias', 'esm.encoder.layer.17.LayerNorm.bias', 'esm.encoder.layer.2

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/77wu/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
  output, inverse_indices, counts = torch._unique2(
  tp = tp.sum(dim=0 if multidim_average == "global" else 1)
/Users/77wu/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

/Users/77wu/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


## Prediction

In [2]:
predictions = trainer.predict(model=model, dataloaders=datamodule.predict_dataloader())


/Users/77wu/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

## Submission CSV

In [19]:
import pandas as pd

predicted_family_ids = []
for batch in predictions:
    for logits in batch:
        print("Logits Shape:", logits.shape) 
        predicted_index = logits.argmax()  
        predicted_label = datamodule.classes[predicted_index.cpu().item()]  
        predicted_family_ids.append(predicted_label) 

df = pd.read_csv("../datafiles/test_data.csv") 

# Save to CSV
submission_df = pd.DataFrame({
    'sequence_name': df['sequence_name'],
    'family_id': predicted_family_ids
})
submission_df.to_csv('submission.csv', index=False)


Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits Shape: torch.Size([25])
Logits S