In [None]:
import csv
import os
import random
import sys
import time
import tokenize

from math import ceil
from typing import Any, Generator

In [None]:
import torch
from accelerate import Accelerator

In [None]:
from pathlib import Path

# Absolute path of the process's current working directory (Path() is ".").
CURRENT_DIR = Path().resolve()

# Ancestors of the working directory: parents[3] is four levels up, parents[4] is five.
# NOTE(review): this only works when the notebook is launched from a fixed depth
# inside the project tree — confirm the expected working directory.
APPS_ROOT = CURRENT_DIR.parents[3]
PROJECT_ROOT = CURRENT_DIR.parents[4]

# SHARED_DIR holds the input CSV read below; STORAGE_DIR holds model checkpoints
# (used by load_model / save_model).
SHARED_DIR = APPS_ROOT / "shared"
STORAGE_DIR = PROJECT_ROOT / "storage"

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.optimization import get_scheduler

In [None]:
import os
import random
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

def load_model(
	model_name: str
) -> tuple[PreTrainedModel, PreTrainedTokenizerBase]:
	"""Load a causal-LM checkpoint and its tokenizer from local storage.

	Args:
		model_name: Name of the checkpoint folder under ``STORAGE_DIR/models``.

	Returns:
		A ``(model, tokenizer)`` pair, both loaded from the same folder.
	"""
	source_dir = os.path.join(STORAGE_DIR, "models", model_name)

	# Model first, tokenizer second; both read from the same folder.
	return (
		AutoModelForCausalLM.from_pretrained(source_dir),
		AutoTokenizer.from_pretrained(source_dir),
	)

def save_model(
	model_name: str,
	model: PreTrainedModel,
	tokenizer: PreTrainedTokenizerBase,
	history: dict | None
) -> None:
	"""Persist a model and tokenizer, plus an optional history CSV.

	Args:
		model_name: Name of the destination folder under ``STORAGE_DIR/models``.
		model: Model saved via ``save_pretrained``.
		tokenizer: Tokenizer saved via ``save_pretrained`` into the same folder.
		history: Data convertible to a ``pandas.DataFrame``. When falsy
			(``None`` or empty) no history file is written.
	"""
	target_dir = os.path.join(STORAGE_DIR, "models", model_name)

	# Model first, then tokenizer, both into the same folder.
	for artifact in (model, tokenizer):
		artifact.save_pretrained(target_dir)

	if not history:
		return

	# Timestamped file name so successive saves do not overwrite each other.
	history_frame = pd.DataFrame(history)
	stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
	history_frame.to_csv(f"{target_dir}/history-{stamp}.csv", index=False)

def set_seed(seed) -> None:
	"""Seed the Python, NumPy and PyTorch RNGs and request deterministic kernels.

	Args:
		seed: Value passed unchanged to every seeding function.
	"""
	# Same order as before: Python, NumPy, torch, torch.cuda.
	seeders = (
		random.seed,
		np.random.seed,
		torch.manual_seed,
		torch.cuda.manual_seed,
	)
	for seeder in seeders:
		seeder(seed)

	# Ask cuDNN / PyTorch for deterministic behaviour.
	torch.backends.cudnn.deterministic = True
	torch.backends.cudnn.benchmark = False
	torch.use_deterministic_algorithms(True)

	# Environment switches exported for the current process.
	os.environ.update({
		"CUBLAS_WORKSPACE_CONFIG": ":16:8",
		"TF_ENABLE_ONEDNN_OPTS": "0",
	})

In [None]:
def process_sequence(sequence) -> str:
	"""Upper-case each element of ``sequence`` and wrap it in square brackets.

	Example: ``"acg"`` becomes ``"[A][C][G]"``.
	"""
	wrapped = [f"[{symbol.upper()}]" for symbol in sequence]
	return "".join(wrapped)

def process_target(label) -> str:
	"""Return the upper-cased ``label`` wrapped in square brackets."""
	upper_label = label.upper()
	return f"[{upper_label}]"

def promptfy(
	sequence: str,
	organism: str,
	hide_prob: float,
	gene: str | None,
	flank_before: str | None,
	flank_after: str | None,
) -> str:
	output = f"<|SEQUENCE|>{sequence}\n"

	if organism:
		if random.random() > hide_prob:
			output += f"<|ORGANISM|>{organism[:10]}\n"

	if gene:
		if random.random() > hide_prob:
			output += f"<|GENE|>{gene[:10]}\n"
	
	if flank_before:
		if random.random() > hide_prob:
			output += f"<|FLANK_BEFORE|>{flank_before}\n"
	
	if flank_after:
		if random.random() > hide_prob:
			output += f"<|FLANK_AFTER|>{flank_after}\n"
	
	output += "<|TARGET|>"

	return output

class DNADatasetFinetune(IterableDataset):
	"""Stream a CSV file row by row and yield tokenized prompts with labels.

	Each CSV row must provide the columns ``sequence``, ``target``,
	``organism``, ``gene``, ``flankBefore`` and ``flankAfter``.
	"""

	def __init__(
		self,
		csv_path: str,
		tokenizer,
		dataset_total_length: int,
		feat_hide_prob: float,
		flanks_size: int = 25,
		sequence_max_length: int = 512,
	) -> None:
		"""Store the configuration; the file is not opened until iteration.

		Args:
			csv_path: Path of the CSV file to read.
			tokenizer: Callable tokenizer that also exposes ``encode``.
			dataset_total_length: Value reported by ``len()``; it is taken
				as given and never checked against the file.
			feat_hide_prob: Forwarded to ``promptfy`` as ``hide_prob``.
			flanks_size: Counted twice in the token budget.
			sequence_max_length: Sequence part of the token budget.
		"""
		self.csv_path = csv_path
		self.tokenizer = tokenizer
		self.feat_hide_prob = feat_hide_prob
		self._length = dataset_total_length
		# Token budget: sequence + both flanks + 20 extra.
		# NOTE(review): the extra 20 is presumably headroom for markers/metadata — confirm.
		self.max_length = sequence_max_length + flanks_size * 2 + 20

	def __len__(self):
		"""Return the length supplied at construction time."""
		return self._length

	def __iter__(self) -> Generator[dict[str, torch.Tensor], Any, None]:
		"""Yield one dict of ``input_ids``, ``attention_mask`` and ``labels`` per row."""
		with open(self.csv_path, newline='') as handle:
			for record in csv.DictReader(handle):
				sequence_text = process_sequence(record["sequence"])
				target_text = process_target(record["target"])

				prompt_text = promptfy(
					sequence=sequence_text,
					organism=record["organism"],
					gene=record["gene"],
					flank_before=record["flankBefore"],
					flank_after=record["flankAfter"],
					hide_prob=self.feat_hide_prob,
				)

				# Every prompt is truncated/padded to the same fixed length.
				encoded = self.tokenizer(
					prompt_text,
					truncation=True,
					padding="max_length",
					max_length=self.max_length
				)

				# Labels hold only the encoded target, not the prompt.
				yield {
					"input_ids": torch.tensor(encoded["input_ids"]),
					"attention_mask": torch.tensor(encoded["attention_mask"]),
					"labels": torch.tensor(self.tokenizer.encode(target_text))
				}

class FinetuneDataCollator:
	"""Collate per-example tensor dicts into one right-padded batch."""

	# Batch fields, in the order they appear in the returned dict.
	_FIELDS = ("input_ids", "attention_mask", "labels")

	def __init__(self, tokenizer) -> None:
		"""Keep the tokenizer and a snapshot of its pad token id."""
		self.tokenizer = tokenizer
		self.pad_token_id = tokenizer.pad_token_id

	def __call__(self, batch) -> dict[str, torch.Tensor]:
		"""Pad every field of ``batch`` to the longest example.

		Args:
			batch: Sequence of dicts holding ``input_ids``, ``attention_mask``
				and ``labels`` tensors.

		Returns:
			Dict with the same three keys, each a batch-first padded tensor.
			``input_ids`` is padded with the tokenizer's pad token id,
			``attention_mask`` with 0 and ``labels`` with -100.
		"""
		# Gather every field before padding anything.
		gathered = {
			field: [sample[field] for sample in batch]
			for field in self._FIELDS
		}

		# The pad id is read from the tokenizer at call time.
		fill_values = (self.tokenizer.pad_token_id, 0, -100)

		return {
			field: pad_sequence(gathered[field], batch_first=True, padding_value=fill)
			for field, fill in zip(self._FIELDS, fill_values)
		}

In [None]:
# Run configuration consumed by the cells below.
model_name = "gpt2-exin"  # checkpoint folder name under STORAGE_DIR/models (see load_model)
seed = 1234  # passed to set_seed
uuid = "uuid"  # base name of the input CSV under SHARED_DIR/temp; NOTE(review): looks like a placeholder — confirm
data_length = 30000  # reported by the dataset's __len__; presumably the CSV row count — TODO confirm
batch_size = 1  # DataLoader batch size

In [None]:
# Seed the Python, NumPy and PyTorch RNGs before anything else runs.
set_seed(seed)

# Load the checkpoint and the tokenizer stored under the configured model name.
model, tokenizer = load_model(model_name)

# The input CSV is expected at SHARED_DIR/temp/<uuid>.csv.
data_path = os.path.join(SHARED_DIR, "temp", uuid)

# feat_hide_prob=0.0: an optional prompt field is dropped only if its random draw is exactly 0.0.
dataset = DNADatasetFinetune(
	csv_path=f"{data_path}.csv",
	tokenizer=tokenizer,
	dataset_total_length=data_length,
	feat_hide_prob=0.0,
)

# Batches are assembled (and padded) by the custom collator.
dataloader = DataLoader(
	dataset,
	batch_size=batch_size,
	collate_fn=FinetuneDataCollator(tokenizer),
)

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA. The evaluation loop moves its
# tensors to `model.device`, so it needs no further device handling.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

In [None]:
# Running count of correct predictions; incremented by the evaluation loop below.
hit = 0

In [None]:
from tqdm import tqdm

In [None]:
# Count how often the single generated token equals the first target token.
# The counter is reset here so re-running this cell never adds onto a previous run.
hit = 0

with torch.no_grad():
	for batch in tqdm(dataloader):
		# Look tensors up by key rather than relying on the dict's value order.
		input_ids = batch["input_ids"].to(model.device)
		attention_mask = batch["attention_mask"].to(model.device)
		labels = batch["labels"].to(model.device)

		# With max_new_tokens=1 the last position of each row is the generated token.
		# NOTE(review): the dataset pads every prompt to a fixed length; whether the
		# model then generates right after real tokens depends on the tokenizer's
		# padding side — confirm it is left-padded for this checkpoint.
		responses = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=1)

		# Compare each row with its own first label token, so batch_size > 1 is
		# scored correctly (previously every row was compared with row 0's label).
		predicted = responses[:, -1]
		expected = labels[:, 0]
		hit += int((predicted == expected).sum())

In [None]:
# Accuracy: share of examples whose generated token matched the target.
# Uses the configured dataset length instead of repeating the literal 30000,
# so the two cannot drift apart.
# NOTE(review): this assumes the CSV really contains `data_length` rows — confirm.
hit / data_length