# In[ ]:
from transformers import Trainer, TrainingArguments
from dataclasses import dataclass
from torch import Tensor
from typing import List

class Config:
    """Hyper-parameters, data paths and run settings for GPT-2 DNA training.

    Every setting is a plain instance attribute, so values can be overridden
    after construction, e.g. ``cfg = Config(); cfg.lr = 3e-4``.
    """

    def __init__(self):
        # --- data ------------------------------------------------------------
        self.dataset: DatasetType = DatasetType.VIRAL
        self.model_name: str = 'gpt2'
        self.file_path: str = "./data/viral.1.1.genomic.fna"
        # Also used as GPT-2's context size (n_positions) when the model is built.
        self.sequence_length: int = 400
        # NOTE(review): the fields below are presumably consumed by
        # load_datasets (not visible here) — confirm their exact meaning there.
        self.stride: int = 200
        self.split_ratio: float = 0.5
        self.substrings_per_seq: int = 20
        self.num_seqs: int = 10000
        self.sequences_shuffle: bool = True

        # --- batching ----------------------------------------------------------
        self.train_bs: int = 64
        self.val_bs: int = 128

        # --- model architecture -----------------------------------------------
        self.n_embed: int = 512
        self.n_layer: int = 4
        self.n_head: int = 16

        # --- optimisation -------------------------------------------------------
        self.lr: float = 1e-4
        # Weight decay is disabled. This field used to be assigned twice
        # (0.01, then 0.00); 0.0 is the value that was actually in effect.
        self.weight_decay: float = 0.0
        self.num_epochs: int = 200
        # NOTE(review): not read by any code visible here (no early-stopping
        # callback is configured).
        self.early_stopping_patience: int = 5
        self.warmup_steps: int = 10

        # --- logging -------------------------------------------------------------
        # NOTE(review): not read by any code visible here.
        self.print_every: int = 20
        self.logging_steps: int = 20

@dataclass
class InputExample:
    """One causal-language-model sample derived from a single DNA sequence.

    Produced by ``DNADataset.__getitem__`` and handed to ``Trainer``, which
    batches the two fields below.
    """

    # Token ids that are fed to the model.
    input_ids: Tensor
    # Token ids the model's loss is computed against.
    labels: Tensor

class DNADataset(Dataset):
    """Map-style dataset that tokenizes DNA sequences lazily for causal-LM training.

    Args:
        sequences: Raw sequence strings; each one becomes one example.
        tokenizer: Tokenizer exposing ``encode(seq, return_tensors='pt')``,
            which returns the token ids of ``seq`` as a tensor.
    """

    def __init__(self, sequences: List[str], tokenizer: SequenceTokenizer):
        self.sequences = sequences
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, i):
        """Return sequence ``i`` as an ``InputExample`` whose labels equal its input ids.

        ``GPT2LMHeadModel`` shifts ``labels`` by one position internally when it
        computes the loss. Labels must therefore NOT be pre-shifted here; doing
        so would train the model to predict the token two positions ahead.
        """
        sequence = self.sequences[i]
        input_ids = self.tokenizer.encode(sequence, return_tensors='pt')
        # Hugging Face-style tokenizers return shape (1, L) for
        # return_tensors='pt'. Drop that singleton batch dim so the collator
        # stacks examples into (batch, L) rather than (batch, 1, L).
        # NOTE(review): assumes SequenceTokenizer follows that convention; a
        # 1-D result is left untouched.
        if input_ids.dim() > 1:
            input_ids = input_ids.squeeze(0)
        # NOTE(review): sequences are no longer trimmed by one token, so confirm
        # the encoded length stays <= the model's n_positions.
        return InputExample(input_ids=input_ids, labels=input_ids.clone())


def compute_metrics(pred):
    """Compute next-token accuracy for ``Trainer`` evaluation.

    ``GPT2LMHeadModel`` scores the logits at position ``t`` against the label at
    position ``t + 1`` (labels are shifted inside the model). This metric applies
    the same shift so it measures exactly what the loss optimises. Positions
    labelled -100 (the Hugging Face ignore index, used for padding) are excluded.

    Args:
        pred: ``EvalPrediction``-like object with ``predictions`` (logits of
            shape ``(batch, seq, vocab)``, or a tuple whose first item is the
            logits) and ``label_ids`` of shape ``(batch, seq)``.

    Returns:
        ``{'accuracy': float}``; the value is NaN when no position is scorable.
    """
    import numpy as np

    logits = pred.predictions
    if isinstance(logits, tuple):
        # Some models return extra outputs alongside the logits.
        logits = logits[0]

    # Prediction at t is compared with the label at t + 1.
    preds = np.asarray(logits).argmax(-1)[:, :-1]
    labels = np.asarray(pred.label_ids)[:, 1:]

    mask = labels != -100
    n_valid = int(mask.sum())
    if n_valid == 0:
        return {'accuracy': float('nan')}
    acc = float((preds[mask] == labels[mask]).sum()) / n_valid
    return {'accuracy': acc}

def main(config: Config) -> None:
    """Train a randomly initialised GPT-2 on DNA sequences with the HF ``Trainer``.

    Loads the train/validation splits, logs the run to Weights & Biases
    (project ``GPT2_DNA``), builds a GPT-2 sized by ``config`` and runs
    ``Trainer.train()``. Checkpoints go to ``./results`` and logs to ``./logs``.

    Args:
        config: Hyper-parameters and paths for the run.
    """
    train_seqs, val_seqs = load_datasets(config)

    tokenizer = SequenceTokenizer()
    train_dataset = DNADataset(train_seqs, tokenizer)
    val_dataset = DNADataset(val_seqs, tokenizer)

    wandb.init(project='GPT2_DNA', name='viruses', config=config)

    gpt2_config = GPT2Config(vocab_size=tokenizer.vocab_size,
                             n_positions=config.sequence_length,
                             n_ctx=config.sequence_length,
                             n_embd=config.n_embed,
                             n_layer=config.n_layer,
                             n_head=config.n_head)
    # No manual .to(device) or DataParallel here: Trainer places the model on
    # the right device and wraps it for multi-GPU itself.
    model = GPT2LMHeadModel(gpt2_config)

    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        print(f"Using {n_gpus} GPUs for training.")

    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=config.num_epochs,
        per_device_train_batch_size=config.train_bs,
        per_device_eval_batch_size=config.val_bs,
        # These two are TrainingArguments options; Trainer() rejects them.
        gradient_accumulation_steps=getattr(config, 'gradient_accumulation_steps', 2),
        # Mixed precision needs CUDA; TrainingArguments raises on CPU-only hosts.
        fp16=torch.cuda.is_available(),
        warmup_steps=config.warmup_steps,
        learning_rate=config.lr,
        weight_decay=config.weight_decay,
        logging_dir='./logs',
        logging_steps=config.logging_steps,
    )
    # NOTE(review): no evaluation strategy is configured, so compute_metrics
    # only runs if trainer.evaluate() is called, and
    # config.early_stopping_patience is not used.

    # No explicit optimizer: Trainer builds AdamW from learning_rate and
    # weight_decay above, together with a warmup learning-rate schedule.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    trainer.train()


if __name__ == "__main__":
    # Script / notebook entry point: build the default settings and train.
    config = Config()
    main(config=config)
