# Definitions

In [6]:
import os
import re
import random
from enum import Enum
from typing import List, Tuple, DefaultDict, Set
from collections import defaultdict, Counter

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Config, GPT2LMHeadModel, AdamW
from tqdm.auto import tqdm
from Bio import SeqIO
import wandb


os.environ["WANDB_NOTEBOOK_NAME"] = 'GPT_dBG.ipynb'


class DatasetType(Enum):
    HMM = "HMM"
    VIRAL = "viral"


class Config:
    def __init__(self):
        self.dataset: DatasetType = DatasetType.HMM
        self.model_name: str = 'gpt2'
        self.file_path: str = "./data/viral.1.1.genomic.fna"
        self.sequence_length: int = 400
        self.stride: int = 200
        self.split_ratio: float = 0.5
        self.substrings_per_seq: int = 20
        self.num_seqs: int = 1000
        self.sparsity = 1.1
        self.num_hidden_states = 1000
        self.sequences_shuffle: bool = True
        self.train_bs: int = 128
        self.val_bs: int = 256
        self.n_embed: int = 512
        self.n_layer: int = 6
        self.n_head: int = 16
        self.lr: float = 1e-4
        self.weight_decay: float = 0.00
        self.num_epochs: int = 200
        self.early_stopping_patience: int = 5
        self.print_every: int = 20

class SequenceTokenizer:
    def __init__(self):
        self.alphabet = {'A', 'C', 'G', 'T'}
        self.token_to_idx = {char: i for i, char in enumerate(self.alphabet)}
        self.idx_to_token = {i: char for i, char in enumerate(self.alphabet)}
        self.vocab_size = len(self.token_to_idx)

    def encode(self, sequence: str, return_tensors: str = "pt") -> torch.Tensor:
        tokens = [self.token_to_idx[char] for char in sequence]
        if return_tensors == "pt":
            tokens = torch.tensor(tokens, dtype=torch.long)
        return tokens

    def decode(self, tokens: torch.Tensor) -> str:
        sequence = [self.idx_to_token[token.item()] for token in tokens]
        return ''.join(sequence)


class DNADataset(Dataset):
    def __init__(self, sequences: List[str], tokenizer: SequenceTokenizer):
        self.sequences = sequences
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, i):
        sequence = self.sequences[i]
        inputs = self.tokenizer.encode(sequence, return_tensors='pt')
        targets = inputs[1:].clone()
        inputs = inputs[:-1]
        return inputs, targets


def compute_char_probabilities(sequences: List[str]) -> None:
    counts = defaultdict(int)
    total_count = 0
    for seq in sequences:
        cd = Counter(seq)
        for c, count in cd.items():
            counts[c] += count
            total_count += count
    pcounts = {char: c / total_count for char, c in counts.items()}
    print(pcounts)


def read_fna(file_path: str, shuffle: bool = False) -> List[str]:
    sequences = []
    with open(file_path, "r") as f:
        for record in SeqIO.parse(f, "fasta"):
            sequences.append(str(record.seq))
    if shuffle:
        random.shuffle(sequences)
    return sequences


def generate_markov_chain(n: int, sparsity: float) -> torch.Tensor:
    transition_probs = torch.rand(n, n)
    for i in range(n):
        num_outgoing_states = torch.poisson(torch.tensor([float(sparsity)])).int().item()
        num_outgoing_states = min(n, max(1, num_outgoing_states))
        _, indices = torch.topk(transition_probs[i], num_outgoing_states)
        mask = torch.zeros_like(transition_probs[i]).scatter_(0, indices, 1).to(torch.bool)
        transition_probs[i] *= mask
    transition_probs = F.normalize(transition_probs, p=1, dim=1)
    return transition_probs


def draw_seq(transition_probs: torch.Tensor, sequence_length: int) -> str:
    NUCLEOTIDES = ['A', 'C', 'G', 'T']
    current_state = 0
    chain = [current_state]
    while len(chain) < sequence_length:
        next_state = torch.multinomial(transition_probs[current_state], num_samples=1)
        current_state = next_state.item()
        chain.append(current_state)
    seq = [NUCLEOTIDES[s % len(NUCLEOTIDES)] for s in chain]
    return ''.join(seq)


def generate_HMM_dataset(sequence_length: int, N: int, sparsity: float, num_hidden_states: int = None) -> Tuple[List[str], List[str]]:
    if num_hidden_states is None:
        num_hidden_states = sequence_length
    transition_matrix = generate_markov_chain(num_hidden_states, sparsity=sparsity)
    all_seqs = [[], []]
    for _ in range(2):
        seqs = []
        while len(seqs) < N:
            seq = draw_seq(transition_matrix, sequence_length)
            if len(seq) >= sequence_length:
                seqs.append(seq)
        all_seqs[_] = seqs
    return all_seqs


def generate_phylo_dataset(sequence_length: int, mutation_rate: float, N: int) -> List[str]:
    NUCLEOTIDES = ['A', 'C', 'G', 'T']
    parent_sequence = ''.join(random.choice(NUCLEOTIDES) for _ in range(sequence_length))
    dataset = [parent_sequence]
    for _ in range(N):
        mutated_sequence = ''
        for nucleotide in parent_sequence:
            if random.random() < mutation_rate:
                mutated_sequence += random.choice(NUCLEOTIDES)
            else:
                mutated_sequence += nucleotide
        dataset.append(mutated_sequence)
        parent_sequence = mutated_sequence
    return dataset


def construct_debruijn_graph(dataset: List[str], k: int) -> DefaultDict[str,Set[str]]:
    graph = defaultdict(set)
    for sequence in dataset:
        for i in range(len(sequence) - k):
            graph[sequence[i:i + k]].add(sequence[i + 1:i + k + 1])
    return graph


def extract_substrings(sequences: List[str], sequence_length: int, stride: int, substrings_per_seq: int) -> List[str]:
    substrings = []
    for sequence in sequences:
        for i in range(0, len(sequence) - sequence_length + 1, stride):
            if i // stride > substrings_per_seq:
                break
            seq = sequence[i:i + sequence_length]
            if bool(re.match("^[ACGT]+$", seq)):
                substrings.append(seq)
    return substrings


def set_random_seed(seed: int) -> None:
    torch.manual_seed(seed)
    random.seed(seed)


def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: DataLoader, device: torch.device, description: str) -> float:
    model.train()
    running_loss = []
    bar = tqdm(train_loader, desc=description)
    for inputs, targets in bar:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs, labels=targets)
        loss = outputs.loss.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())
        bar.set_postfix({"Train Loss": loss.item()})
    return sum(running_loss) / len(running_loss)


def evaluate(model: torch.nn.Module, val_loader: DataLoader, device: torch.device, description: str) -> Tuple[float, float]:
    model.eval()
    total_loss = 0.0
    total_count = 0
    total_acc = 0.0
    bar = tqdm(val_loader, desc=description)
    with torch.no_grad():
        for inputs, targets in bar:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs, labels=targets)
            val_loss = outputs.loss.mean()
            predictions = outputs.logits.argmax(dim=-1)
            accuracy = (predictions == targets).cpu().float().mean().item()
            total_loss += val_loss.item() * inputs.size(0)
            total_count += inputs.size(0)
            total_acc += accuracy
            bar.set_postfix({"Val Loss": val_loss.item(), "Val Accuracy": accuracy})
    avg_val_loss = total_loss / total_count
    avg_accuracy = total_acc / len(val_loader)
    return avg_val_loss, avg_accuracy


def train_loop(model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: DataLoader, val_loader: DataLoader, device: torch.device, config: Config) -> None:
    best_val_loss = float('inf')
    num_epochs_no_improve = 0  # Number of epochs with no improvement in validation loss

    try:
        for epoch in range(config.num_epochs):
            train_loss = train(model, optimizer, train_loader, device, f"Epoch {epoch + 1}/{config.num_epochs} | Training")
            val_loss, val_acc = evaluate(model, val_loader, device, f"Epoch {epoch + 1}/{config.num_epochs} | Validation")
            samples = (epoch+1) * len(train_loader) * config.train_bs
            print(f"Epoch {epoch + 1}/{config.num_epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f} | Val Accuracy: {val_acc:.5f}")
            wandb.log({"epoch": epoch + 1, "samples": samples, "train_loss": train_loss, "val_loss": val_loss, "val_accuracy": val_acc})
            
            if (epoch + 1) % 5 == 0:
                # Save the model weights as an artifact every 10 epochs
                artifact = wandb.Artifact(f"model_weights", type='model')
                artifact.add_file('gpt2_dna.pth')
                wandb.log_artifact(artifact)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                num_epochs_no_improve = 0
            else:
                num_epochs_no_improve += 1
                if num_epochs_no_improve >= config.early_stopping_patience:
                    print(f"Early stopping at epoch {epoch + 1}...")
                    break

    except KeyboardInterrupt:
        print("Interrupted by user")
    finally:
        torch.save(model.state_dict(), 'gpt2_dna.pth')
        wandb.finish()



def load_datasets(config: Config) -> Tuple[List[str], List[str]]:
    if config.dataset == DatasetType.HMM:
        train_seqs, val_seqs = generate_HMM_dataset(config.sequence_length, N=config.num_seqs, sparsity=config.sparsity,
                                                    num_hidden_states=config.num_hidden_states)
    elif config.dataset == DatasetType.VIRAL:
        set_random_seed(42)
        sequences = read_fna(file_path=config.file_path, shuffle=config.sequences_shuffle)
        sequences = sequences[:config.num_seqs]
        sub_seqs = extract_substrings(sequences, sequence_length=config.sequence_length, stride=config.stride,
                                      substrings_per_seq=config.substrings_per_seq)
        compute_char_probabilities(sub_seqs)
        train_size = int(len(sub_seqs) * (1 - config.split_ratio))
        train_seqs = sub_seqs[:train_size]
        val_seqs = sub_seqs[train_size:]
        train_seqs = train_seqs[:config.num_seqs]
        val_seqs = val_seqs[:config.num_seqs]
    return train_seqs, val_seqs



def main(config: Config) -> None:
    train_seqs, val_seqs = load_datasets(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = SequenceTokenizer()
    train_dataset = DNADataset(train_seqs, tokenizer)
    val_dataset = DNADataset(val_seqs, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=config.train_bs, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.val_bs, shuffle=False)

    wandb.init(project='GPT2_DNA', name='viruses', config=config)

    gpt2_config = GPT2Config(vocab_size=tokenizer.vocab_size,
                             n_positions=config.sequence_length,
                             n_ctx=config.sequence_length,
                             n_embd=config.n_embed,
                             n_layer=config.n_layer,
                             n_head=config.n_head)

    model = GPT2LMHeadModel(gpt2_config).to(device)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs for training.")
        model = torch.nn.DataParallel(model)

    optimizer = AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    train_loop(model, optimizer, train_loader, val_loader, device, config)


if __name__ == "__main__":
    config = Config()
    main(config)


[34m[1mwandb[0m: Currently logged in as: [33mamirjoudaki[0m ([33msketch-bros[0m). Use [1m`wandb login --relogin`[0m to force relogin


Using 2 GPUs for training.




Epoch 1/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]



Epoch 1/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1/200 | Train Loss: 1.93486 | Val Loss: 1.44437 | Val Accuracy: 0.24102


Epoch 2/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 2/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 2/200 | Train Loss: 1.42362 | Val Loss: 1.37949 | Val Accuracy: 0.26063


Epoch 3/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 3/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 3/200 | Train Loss: 1.39000 | Val Loss: 1.36998 | Val Accuracy: 0.24952


Epoch 4/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 4/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 4/200 | Train Loss: 1.37466 | Val Loss: 1.37012 | Val Accuracy: 0.24569


Epoch 5/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 5/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 5/200 | Train Loss: 1.36605 | Val Loss: 1.36222 | Val Accuracy: 0.28293


Epoch 6/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 6/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 6/200 | Train Loss: 1.35723 | Val Loss: 1.35101 | Val Accuracy: 0.26968


Epoch 7/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 7/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 7/200 | Train Loss: 1.34799 | Val Loss: 1.34490 | Val Accuracy: 0.27447


Epoch 8/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 8/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 8/200 | Train Loss: 1.34484 | Val Loss: 1.34542 | Val Accuracy: 0.27311


Epoch 9/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 9/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 9/200 | Train Loss: 1.34265 | Val Loss: 1.34121 | Val Accuracy: 0.27308


Epoch 10/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 10/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 10/200 | Train Loss: 1.34078 | Val Loss: 1.33936 | Val Accuracy: 0.27345


Epoch 11/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 11/200 | Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 11/200 | Train Loss: 1.33837 | Val Loss: 1.33756 | Val Accuracy: 0.27326


Epoch 12/200 | Training:   0%|          | 0/8 [00:00<?, ?it/s]

Interrupted by user


0,1
epoch,▁▂▂▃▄▅▅▆▇▇█
samples,▁▂▂▃▄▅▅▆▇▇█
train_loss,█▂▂▁▁▁▁▁▁▁▁
val_accuracy,▁▄▂▂█▆▇▆▆▆▆
val_loss,█▄▃▃▃▂▁▂▁▁▁

0,1
epoch,11.0
samples,11264.0
train_loss,1.33837
val_accuracy,0.27326
val_loss,1.33756


# using trainer 

In [17]:
from transformers import Trainer, TrainingArguments
from dataclasses import dataclass
from torch import Tensor
from typing import List

class Config:
    def __init__(self):
        self.dataset: DatasetType = DatasetType.VIRAL
        self.model_name: str = 'gpt2'
        self.file_path: str = "./data/viral.1.1.genomic.fna"
        self.sequence_length: int = 400
        self.stride: int = 200
        self.split_ratio: float = 0.5
        self.substrings_per_seq: int = 20
        self.num_seqs: int = 10000
        self.sequences_shuffle: bool = True
        self.train_bs: int = 64
        self.val_bs: int = 128
        self.n_embed: int = 512
        self.n_layer: int = 4
        self.n_head: int = 16
        self.lr: float = 1e-4
        self.weight_decay: float = 0.01
        self.num_epochs: int = 200
        self.early_stopping_patience: int = 5
        self.weight_decay: float = 0.00
        self.warmup_steps: int = 10
        self.print_every: int = 20
        self.logging_steps: int = 20

@dataclass
class InputExample:
    """
    A single training/test example for the DNA dataset.
    """

    input_ids: Tensor
    labels: Tensor

class DNADataset(Dataset):
    def __init__(self, sequences: List[str], tokenizer: SequenceTokenizer):
        self.sequences = sequences
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, i):
        sequence = self.sequences[i]
        inputs = self.tokenizer.encode(sequence, return_tensors='pt')
        targets = inputs[1:].clone()
        inputs = inputs[:-1]
        return InputExample(input_ids=inputs, labels=targets)


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = (preds == labels).mean()
    return {'accuracy': acc}

def main(config: Config) -> None:
    train_seqs, val_seqs = load_datasets(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = SequenceTokenizer()
    train_dataset = DNADataset(train_seqs, tokenizer)
    val_dataset = DNADataset(val_seqs, tokenizer)

    wandb.init(project='GPT2_DNA', name='viruses', config=config)

    gpt2_config = GPT2Config(vocab_size=tokenizer.vocab_size,
                             n_positions=config.sequence_length,
                             n_ctx=config.sequence_length,
                             n_embd=config.n_embed,
                             n_layer=config.n_layer,
                             n_head=config.n_head)

    model = GPT2LMHeadModel(gpt2_config).to(device)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs for training.")
        model = torch.nn.DataParallel(model)
    optimizer = AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    training_args = TrainingArguments(
        output_dir='./results',          # output directory
        num_train_epochs=config.num_epochs,              # total # of training epochs
        per_device_train_batch_size=config.train_bs,  # batch size per device during training
        per_device_eval_batch_size=config.val_bs,   # batch size for evaluation
        warmup_steps=config.warmup_steps,                # number of warmup steps for learning rate scheduler
        learning_rate=config.lr,         # learning rate
        weight_decay=config.weight_decay,               # strength of weight decay
        logging_dir='./logs',            # directory for storing logs
        logging_steps=config.logging_steps,
    )
    trainer = Trainer(
        model=model,                         # the instantiated 🤗 Transformers model to be trained
        args=training_args,                  # training arguments, defined above
        optimizers=(optimizer, None),       # optimizer
        gradient_accumulation_steps=2,      # Modify as needed
        fp16=True,                          # if your GPU supports mixed precision
        train_dataset=train_dataset,         # training dataset
        eval_dataset=val_dataset,            # evaluation dataset
        compute_metrics=compute_metrics,     # the function to compute metrics 
    )

    trainer.train()



if __name__ == "__main__":
    config = Config()
    main(config)


{'A': 0.27363855788721336, 'G': 0.23881858764807198, 'C': 0.2252636735913389, 'T': 0.26227918087337576}


VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,▁▄███▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁
train/loss,██████▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁

0,1
train/epoch,143.8
train/global_step,11360.0
train/learning_rate,3e-05
train/loss,1.0913


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669005248695613, max=1.0…

Using 2 GPUs for training.




TypeError: __init__() got an unexpected keyword argument 'gradient_accumulation_steps'

NameError: name 'model' is not defined