In [1]:
! python --version

Python 3.10.14


In [2]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [3]:
import os

print(os.getenv('CUDA_VISIBLE_DEVICES'))

1


In [4]:
# ! pip install --force-reinstall "xformers<0.0.27"
# ! pip install matplotlib
# ! pip install plotly

In [5]:
import sys

sys.path.append('space-model')

In [6]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Installs Unsloth, Xformers (Flash Attention) and all other packages!
# !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes

# Install Flash Attention 2 for softcapping support
import torch
import pandas as pd
import torch.nn.functional as F

# if torch.cuda.get_device_capability()[0] >= 8:
#     !pip install --no-deps packaging ninja einops "flash-attn>=2.6.3"

from sklearn.model_selection import train_test_split

from datasets import Dataset, DatasetDict

from unsloth import FastLanguageModel
from datasets import load_dataset

from transformers import TrainingArguments, Trainer, DataCollatorWithPadding
from unsloth import is_bfloat16_supported

from tqdm.auto import tqdm

from logger import get_logger

from space_model.model import SpaceModel
import space_model.loss as losses

  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


In [7]:
# ! pip install --no-deps --upgrade "flash-attn>=2.6.3"

In [8]:
EVAL_BATCH_SIZE = 1
MAX_SEQ_LENGTH = 1024

In [9]:
# Notebook-wide logger; both arguments are forwarded unchanged to the project's get_logger.
log = get_logger('logs/gemma-2-2b', 'space-imdb-ft')

In [10]:
# With CUDA_VISIBLE_DEVICES=1 (set at the top of the notebook) only one GPU is
# visible to this process, and PyTorch addresses it as index 0.
device_id = 0

In [11]:
class SpaceModelForSequenceClassification(torch.nn.Module):
    """Sequence classifier built on the last hidden layer of a causal LM.

    The last hidden state of ``base_model`` is passed through a ``SpaceModel``.
    Its ``logits`` output (last dimension ``n_concept_spaces * n_latent``) feeds
    a linear classifier with ``n_concept_spaces`` outputs, one per class.

    Args:
        base_model: Unsloth / Hugging Face causal LM, called with
            ``output_hidden_states=True``.
        n_embed: hidden size of ``base_model`` (the SpaceModel input size).
        n_latent: size of each concept space.
        n_concept_spaces: number of concept spaces, which is also the number
            of classes.
        l1: weight of ``losses.inter_space_loss``.
        l2: weight of ``losses.intra_space_loss``.
        ce_w: weight of the cross-entropy term.
        fine_tune: if True, freeze every ``base_model`` parameter so that only
            the SpaceModel and the classifier receive gradients.
    """

    def __init__(self, base_model, n_embed=3, n_latent=3, n_concept_spaces=2, l1=1e-3, l2=1e-4, ce_w=1.0,
                 fine_tune=True):
        super().__init__()

        if fine_tune:
            for p in base_model.parameters():
                p.requires_grad_(False)

        self.device = base_model.device

        self.base_model = base_model

        self.space_model = SpaceModel(n_embed, n_latent, n_concept_spaces, output_concept_spaces=True)

        self.classifier = torch.nn.Linear(n_concept_spaces * n_latent, n_concept_spaces)

        self.l1 = l1
        self.l2 = l2
        self.ce_w = ce_w

    def to(self, *args, **kwargs):
        """Move or cast the module like ``nn.Module.to`` and keep ``self.device`` in sync.

        Every ``nn.Module.to`` calling form is accepted (device, dtype, tensor,
        keyword arguments). ``self.device`` is re-read from the classifier
        weights, so a dtype-only cast leaves it pointing at a device.
        """
        super().to(*args, **kwargs)
        self.device = self.classifier.weight.device
        return self

    def to_inference(self):
        """Put the whole module in eval mode, then enable Unsloth inference mode on the base model."""
        # Turns off dropout etc. in the SpaceModel/classifier head as well,
        # not only in the base model.
        self.eval()
        FastLanguageModel.for_inference(self.base_model)

    def to_training(self):
        """Put the whole module in train mode, then enable Unsloth training mode on the base model."""
        self.train()
        FastLanguageModel.for_training(self.base_model)

    def forward(self, input_ids, attention_mask, labels=None):
        """Classify a batch and, when labels are given, compute the training loss.

        Args:
            input_ids: token ids of the batch.
            attention_mask: attention mask matching ``input_ids``.
            labels: optional class-index tensor. When given, the loss is
                ``ce_w * CE + l1 * inter_space_loss + l2 * intra_space_loss``.

        Returns:
            dict with ``logits`` of shape (B, n_concept_spaces) and ``loss``
            (the float 0.0 when ``labels`` is None).
        """
        # Only the hidden states are needed. ``max_new_tokens`` is a generate()
        # option and has no effect on a forward pass, so it is not passed.
        embed = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-1].float()  # (B, seq_len, n_embed)

        out = self.space_model(embed)  # SpaceModelOutput(logits=(B, n_concept_spaces * n_latent), ...)

        concept_hidden = out.logits

        logits = self.classifier(concept_hidden)

        loss = 0.0
        if labels is not None:
            loss = self.ce_w * F.cross_entropy(logits, labels)
            loss += self.l1 * losses.inter_space_loss(out.concept_spaces, labels) + self.l2 * losses.intra_space_loss(
                out.concept_spaces)

        return {"logits": logits, "loss": loss}

    def from_pretrained(self, path):
        """Load the SpaceModel and classifier weights written by ``save_pretrained``.

        The base model is not reloaded. Tensors are mapped onto ``self.device``,
        so checkpoints saved on another GPU still load. ``weights_only=True``
        restricts unpickling to plain tensor state.
        """
        self.space_model.load_state_dict(
            torch.load(f"{path}/space_model.pth", map_location=self.device, weights_only=True))
        self.classifier.load_state_dict(
            torch.load(f"{path}/classifier.pth", map_location=self.device, weights_only=True))
        return self

    def save_pretrained(self, path):
        """Save the base model under ``{path}/base`` and the head weights as ``.pth`` files in ``path``."""
        self.base_model.save_pretrained(f"{path}/base")
        torch.save(self.space_model.state_dict(), f"{path}/space_model.pth")
        torch.save(self.classifier.state_dict(), f"{path}/classifier.pth")

In [12]:
def get_model_tokenizer(max_seq_length=1024, dtype=None, load_in_4bit=True, add_lora=False, load_from=None,
                        model_name="unsloth/gemma-2-2b"):
    """Load an Unsloth causal LM plus tokenizer and wrap it in a SpaceModel classifier.

    Args:
        max_seq_length: maximum sequence length passed to Unsloth.
        dtype: model dtype (None lets Unsloth choose).
        load_in_4bit: load the base model with 4-bit quantization.
        add_lora: attach LoRA adapters (r=8) to the attention and MLP projections.
        load_from: optional directory with ``space_model.pth`` / ``classifier.pth``
            to restore the classification head from.
        model_name: Hugging Face id or local path of the base model
            (e.g. ``"models/space-gemma-2-2b/base"``).

    Returns:
        ``(SpaceModelForSequenceClassification, tokenizer)``. The tokenizer
        truncates and pads on the left. The wrapper is moved to
        ``cuda:{device_id}``, using the notebook-level ``device_id``.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        # Load on the same GPU the wrapper is moved to at the end of this function.
        device_map={'': device_id},
    )

    if add_lora:
        model = FastLanguageModel.get_peft_model(
            model,
            r=8,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj", ],
            lora_alpha=16,
            lora_dropout=0,  # Supports any, but = 0 is optimized
            bias="none",  # Supports any, but = "none" is optimized
            # "unsloth" gradient checkpointing uses less VRAM than True
            use_gradient_checkpointing="unsloth",
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )

    # Keep the end of long prompts (where the classification slot is) and left-pad.
    tokenizer.truncation_side = 'left'
    tokenizer.padding_side = 'left'

    space_model = SpaceModelForSequenceClassification(
        model,
        # Read from the model so other base models work (2304 for gemma-2-2b).
        n_embed=model.config.hidden_size,
        n_latent=256,
        n_concept_spaces=2,
        l1=1e-3,
        l2=1e-7,
        ce_w=1.0,
        fine_tune=False
    )

    if load_from:
        space_model.from_pretrained(load_from)

    space_model.to(f"cuda:{device_id}")

    return space_model, tokenizer

In [13]:
def eval(f):
    """Decorator: switch ``model`` (first positional argument) to inference mode and run ``f`` under ``torch.no_grad()``.

    NOTE: this shadows the builtin ``eval`` in the notebook namespace.
    """
    from functools import wraps

    @wraps(f)  # keep f's __name__/__doc__ for logging and introspection
    def wrapper(model, *args, **kwargs):
        model.to_inference()
        with torch.no_grad():
            return f(model, *args, **kwargs)

    return wrapper


def train(f):
    """Decorator: switch ``model`` (first positional argument) to training mode before calling ``f``."""
    from functools import wraps

    @wraps(f)
    def wrapper(model, *args, **kwargs):
        model.to_training()
        return f(model, *args, **kwargs)

    return wrapper


def count_parameters(model):
    """Return the total number of elements in the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [14]:
model, tokenizer = get_model_tokenizer(
    max_seq_length=MAX_SEQ_LENGTH,
    add_lora=False,
    # load_from="models/space-gemma-2-2b"
)

Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!
To install flash-attn, do the below:

pip install --no-deps --upgrade "flash-attn>=2.6.3"
==((====))==  Unsloth 2024.8: Fast Gemma2 patching. Transformers = 4.43.4.
   \\   /|    GPU: NVIDIA RTX A5000. Max memory: 23.679 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.26.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


In [15]:
count_parameters(model)

1180674

In [16]:
EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN

HATE_TOKEN = 'Hate'
NORMAL_TOKEN = 'Normal'

FAKE_TOKEN = 'Fake'
TRUTH_TOKEN = 'Truth'

POS_TOKEN = 'Positive'
NEG_TOKEN = 'Negative'

In [17]:
def prepare_imdb(tokenizer, device, seed):
    """Build IMDb train/val/test splits with prompt-formatted text columns.

    The IMDb ``train`` split is divided 80/20 into train and val using the
    fixed seed 42. The IMDb ``test`` split is used unchanged as test.
    Two columns are added: ``ref`` (prompt + gold label word + EOS) and
    ``prompt`` (the same prompt with an empty answer slot).

    NOTE(review): ``tokenizer``, ``device`` and ``seed`` are accepted for a
    uniform interface with the other ``prepare_*`` helpers but are unused;
    the split always uses seed 42.
    """
    # The 4-space indentation inside the template is intentional. It matches,
    # byte for byte, the indented triple-quoted template used so far, so
    # prompts stay identical to earlier runs.
    template = '### Text:\n    {}\n    \n    ### Classification:\n    {}'

    imdb = load_dataset("imdb")
    holdout = imdb['train'].train_test_split(test_size=0.2, seed=42)
    splits = DatasetDict({
        'train': holdout['train'],
        'test': imdb['test'],
        'val': holdout['test'],
    })

    def formatting_prompts_func(examples):
        refs, prompts = [], []
        for review, label in zip(examples["text"], examples["label"]):
            answer = NEG_TOKEN if label == 0 else POS_TOKEN
            refs.append(template.format(review, answer) + EOS_TOKEN)
            prompts.append(template.format(review, ""))
        return {"ref": refs, 'prompt': prompts}

    return splits.map(formatting_prompts_func, batched=True)


def prepare_hateoffensive(tokenizer, device, seed):
    """Build train/val/test splits of tdavidson/hate_speech_offensive with prompt text.

    Only the source ``train`` split is used. It is divided 80/10/10 (seed 42)
    into train / val / test. In the ``ref`` text, ``class`` 0 or 1 is rendered
    as ``HATE_TOKEN`` and any other class as ``NORMAL_TOKEN``. The ``class``
    column is renamed to ``label``.

    NOTE(review): ``label`` keeps the original 3-way class ids while the text
    target is binary. A classifier with 2 outputs (as built in
    ``get_model_tokenizer``) will reject label 2 in cross-entropy. Confirm
    whether labels should be binarized before training on this dataset.
    NOTE(review): ``tokenizer``, ``device`` and ``seed`` are unused.
    """
    dataset = load_dataset("tdavidson/hate_speech_offensive")

    # Same 4-space-indented layout as the original triple-quoted template.
    PROMPT = '### Tweet:\n    {}\n    \n    ### Classification:\n    {}'

    def formatting_prompts_func(examples):
        outputs = [
            HATE_TOKEN if label in [1, 0] else NORMAL_TOKEN for label in examples["class"]
        ]
        texts = []
        prompts = []
        for raw_tweet, output in zip(examples["tweet"], outputs):
            # Strip once and use the same text for both `ref` and `prompt`.
            # Previously only `ref` was stripped, so the prompt differed from the
            # prefix of its reference.
            tweet = raw_tweet.strip()
            texts.append(PROMPT.format(tweet, output) + EOS_TOKEN)
            prompts.append(PROMPT.format(tweet, ""))
        return {"ref": texts, 'prompt': prompts}

    dataset = dataset.map(formatting_prompts_func, batched=True)

    # 80% train; the remaining 20% is halved into test and val.
    train_testvalid = dataset['train'].train_test_split(test_size=0.2, seed=42)
    test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=42)

    dataset = DatasetDict({
        'train': train_testvalid['train'],
        'test': test_valid['train'],
        'val': test_valid['test'],
    })
    return dataset.rename_columns({'class': 'label'})


def prepare_fake(tokenizer, device, seed):
    """Build train/val/test splits of the fake-news CSV with prompt-formatted text.

    Rows of ``data/fake_train.csv`` with a missing ``title`` or ``text`` are
    dropped. The rest are split 80/10/10 (``random_state=seed``) into
    train / test / val. ``final`` is ``'Title: <title> Text: <text>'``.
    Two columns are added: ``ref`` (prompt + label word + EOS) and
    ``prompt`` (empty answer slot).

    NOTE(review): label 0 is rendered as FAKE_TOKEN and 1 as TRUTH_TOKEN;
    confirm this matches the CSV's label convention.
    NOTE(review): ``tokenizer`` and ``device`` are unused.
    """
    # Only fake_train.csv is needed: test and val are carved out of it below.
    train_df = pd.read_csv('data/fake_train.csv', index_col=0)
    # Drop incomplete rows first so `final` is never built from NaN values.
    train_df = train_df[train_df['title'].notnull() & train_df['text'].notnull()].copy()
    train_df['final'] = 'Title: ' + train_df['title'] + ' Text: ' + train_df['text']

    train_split, test_split = train_test_split(train_df, test_size=0.2, random_state=seed)
    test_split, val_split = train_test_split(test_split, test_size=0.5, random_state=seed)

    # preserve_index=False: the shuffled pandas index would otherwise become a
    # spurious `__index_level_0__` column.
    dataset = DatasetDict({
        'train': Dataset.from_pandas(train_split[['final', 'label']], preserve_index=False),
        'test': Dataset.from_pandas(test_split[['final', 'label']], preserve_index=False),
        'val': Dataset.from_pandas(val_split[['final', 'label']], preserve_index=False),
    })

    # Same 4-space-indented layout as the original triple-quoted template.
    PROMPT = '### Text:\n    {}\n    \n    ### Classification:\n    {}'

    def formatting_prompts_func(examples):
        outputs = [
            FAKE_TOKEN if label == 0 else TRUTH_TOKEN for label in examples["label"]
        ]
        texts = []
        prompts = []
        for text, output in zip(examples["final"], outputs):
            texts.append(PROMPT.format(text, output) + EOS_TOKEN)
            prompts.append(PROMPT.format(text, ""))
        return {"ref": texts, 'prompt': prompts}

    return dataset.map(formatting_prompts_func, batched=True)

In [18]:
dataset = prepare_imdb(tokenizer, f"cuda:{device_id}", seed=3407)
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'ref', 'prompt'],
        num_rows: 20000
    })
    test: Dataset({
        features: ['text', 'label', 'ref', 'prompt'],
        num_rows: 25000
    })
    val: Dataset({
        features: ['text', 'label', 'ref', 'prompt'],
        num_rows: 5000
    })
})

In [19]:
dataset['test']['ref'][0]

'### Text:\n    I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn\'t match the background, and painfully one-dimensional characters cannot be overcome with a \'sci-fi\' setting. (I\'m sure there are those of you out there who think Babylon 5 is good sci-fi TV. It\'s not. It\'s clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It\'s really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it\'s rubbi

In [20]:
model.base_model

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear4bit(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): GemmaFixedRotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear4bit(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm()
        (post_attention_layernorm): Gemma2RMSNorm()
        (pre

In [21]:
# model.inference() # Enable native 2x faster inference
# inputs = tokenizer(
#     [
#         dataset['test']['prompt'][12],
#     ], return_tensors="pt").to(f"cuda:{device_id}")
# 
# outputs = model(**inputs)
# outputs

In [22]:
# outputs.logits

In [23]:
TOKENS = {
    'Positive': ...,
    'Negative': ...,

    'Hate': ...,
    'Normal': ...,
    'Fake': ...,
    'Truth': ...
}

In [24]:
# Replace each placeholder in TOKENS with the vocabulary id of its label word.
# Each word must encode to exactly one token; silently keeping only the first
# piece of a multi-token word would store the wrong id.
for key in TOKENS:
    code = tokenizer.encode(key, add_special_tokens=False)
    print(f"{key}: {code}")
    if len(code) != 1:
        raise ValueError(f"{key!r} encodes to {len(code)} tokens, expected exactly 1: {code}")
    TOKENS[key] = code[0]

Positive: [35202]
Negative: [39654]
Hate: [88060]
Normal: [15273]
Fake: [41181]
Truth: [55882]


In [25]:
def gpu_stats(device_id=0):
    """Return the name, total memory and peak reserved memory of one CUDA device.

    Memory values are in GiB, rounded to 3 decimals.

    Args:
        device_id: CUDA device index to query (after CUDA_VISIBLE_DEVICES remapping).

    Returns:
        dict with ``gpu`` (device name), ``max_memory`` (total device memory)
        and ``start_gpu_memory`` (peak memory reserved by PyTorch's caching
        allocator on that device). The key names are kept for existing logs.
    """
    gib = 1024 ** 3
    props = torch.cuda.get_device_properties(device_id)
    # Query the same device as `props`. Without an argument PyTorch reports
    # the *current* device, which need not be `device_id`.
    peak_reserved = round(torch.cuda.max_memory_reserved(device_id) / gib, 3)
    total_memory = round(props.total_memory / gib, 3)
    return {'gpu': props.name, 'max_memory': total_memory, 'start_gpu_memory': peak_reserved}

In [26]:
@eval
def evaluate(model, tokenizer, dataset, batch_size, threshold=0.5):
    """Run ``model`` over ``dataset`` and report macro-averaged classification metrics.

    Runs under the ``eval`` decorator: ``model.to_inference()`` and
    ``torch.no_grad()``. Each batch's ``prompt`` strings are tokenized with
    padding and truncation to ``MAX_SEQ_LENGTH``. The prediction is the argmax
    of the model's ``logits``.

    Args:
        model: module returning a dict with ``"logits"`` and exposing ``.device``.
        tokenizer: tokenizer used to encode ``batch["prompt"]``.
        dataset: dataset with ``prompt`` and ``label`` columns.
        batch_size: evaluation batch size.
        threshold: unused; kept for backward compatibility.

    Returns:
        ``(accuracy, f1, precision, recall, preds, labels)``. The metrics are
        macro-averaged; ``preds`` and ``labels`` are lists of Python ints.
    """
    eval_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    preds = []
    labels = []
    for i, batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader)):
        inputs = tokenizer(batch["prompt"], return_tensors="pt", padding=True, truncation=True,
                           max_length=MAX_SEQ_LENGTH).to(model.device)

        outputs = model(**inputs)  # {"logits": (B, n_classes), "loss": ...}

        # Collect plain ints rather than 0-d tensors so sklearn and callers get simple values.
        preds.extend(torch.argmax(outputs['logits'], dim=-1).tolist())
        labels.extend(torch.as_tensor(batch["label"]).tolist())

        if (i + 1) % 1000 == 0:
            log.warning(f'GPU Stats: {gpu_stats(device_id)}')

    val_acc = accuracy_score(labels, preds)
    val_f1 = f1_score(labels, preds, average='macro')
    val_precision = precision_score(labels, preds, average='macro')
    val_recall = recall_score(labels, preds, average='macro')

    log.info(f"Accuracy: {val_acc}, F1: {val_f1}, Precision: {val_precision}, Recall: {val_recall}")
    return val_acc, val_f1, val_precision, val_recall, preds, labels

In [27]:
# eval_acc, eval_f1, eval_precision, eval_recall, preds, labels = evaluate(model, tokenizer, dataset["test"],
# batch_size = EVAL_BATCH_SIZE)

In [28]:
def compute_metrics(pred):
    """Trainer metric hook: accuracy plus macro-averaged F1, precision and recall.

    ``pred`` is an ``EvalPrediction``; the predicted class is the argmax of
    ``pred.predictions`` along the last axis.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)

    metrics = {'accuracy': accuracy_score(y_true, y_pred)}
    for name, scorer in (('f1', f1_score), ('precision', precision_score), ('recall', recall_score)):
        metrics[name] = scorer(y_true, y_pred, average='macro')
    return metrics

In [29]:
@train
def training(model, tokenizer, dataset):
    """Fine-tune ``model`` with a Hugging Face ``Trainer``.

    Runs under the ``train`` decorator (``model.to_training()`` first). Trains
    on ``dataset['train']`` for 2 epochs and evaluates on ``dataset['val']``
    every 500 steps with ``compute_metrics``. No checkpoints are saved.

    Args:
        model: module whose forward returns a dict with ``"loss"`` and ``"logits"``.
        tokenizer: tokenizer for ``DataCollatorWithPadding``.
        dataset: tokenized DatasetDict with ``'train'`` and ``'val'`` splits.

    Returns:
        ``(trainer, train_output)`` where ``train_output`` is the result of
        ``trainer.train()``.
    """
    args = TrainingArguments(
        output_dir="models/gemma-2-2b",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # effective batch size of 16 per device
        num_train_epochs=2,
        warmup_steps=500,
        learning_rate=2e-5,
        weight_decay=0.01,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        # `eval_strategy` replaces the deprecated `evaluation_strategy`.
        eval_strategy='steps',
        eval_steps=500,
        logging_steps=250,
        save_strategy='no',
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        seed=3407,
    )
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset['train'],
        eval_dataset=dataset['val'],
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        args=args,
        compute_metrics=compute_metrics,
    )
    stats = trainer.train()
    return trainer, stats

In [30]:
def tokenize_function(examples):
    """Tokenize the ``prompt`` column, padding and truncating every example to MAX_SEQ_LENGTH.

    NOTE(review): training inputs are padded to the full MAX_SEQ_LENGTH here,
    while ``evaluate`` pads only to the longest prompt in each batch. The
    SpaceModel receives no attention mask, so it may see different inputs at
    train and eval time; confirm whether this matters.
    """
    encode_kwargs = dict(padding="max_length", truncation=True, max_length=MAX_SEQ_LENGTH)
    return tokenizer(examples["prompt"], **encode_kwargs)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [31]:
trainer, stats = training(model, tokenizer, tokenized_datasets)
stats

Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
500,0.6533,0.640074,0.6574,0.638718,0.69969,0.65795
1000,0.599,0.582474,0.737,0.73672,0.737863,0.736926
1500,0.5562,0.557908,0.7522,0.751661,0.754187,0.752092
2000,0.5382,0.548993,0.7546,0.752898,0.762239,0.754804
2500,0.5289,0.539992,0.7608,0.760796,0.760841,0.760814


TrainOutput(global_step=2500, training_loss=0.58227255859375, metrics={'train_runtime': 5650.5626, 'train_samples_per_second': 7.079, 'train_steps_per_second': 0.442, 'total_flos': 0.0, 'train_loss': 0.58227255859375, 'epoch': 2.0})

In [32]:
eval_acc, eval_f1, eval_precision, eval_recall, preds, labels = evaluate(model, tokenizer, dataset["test"], batch_size=EVAL_BATCH_SIZE)

  4%|▍         | 998/25000 [01:05<26:48, 14.92it/s]  GPU Stats: {'gpu': 'NVIDIA RTX A5000', 'max_memory': 23.679, 'start_gpu_memory': 8.662}
  8%|▊         | 1998/25000 [02:06<25:09, 15.24it/s]GPU Stats: {'gpu': 'NVIDIA RTX A5000', 'max_memory': 23.679, 'start_gpu_memory': 8.662}
 12%|█▏        | 2998/25000 [03:07<22:15, 16.48it/s]GPU Stats: {'gpu': 'NVIDIA RTX A5000', 'max_memory': 23.679, 'start_gpu_memory': 8.662}
 16%|█▌        | 3998/25000 [04:08<21:05, 16.60it/s]GPU Stats: {'gpu': 'NVIDIA RTX A5000', 'max_memory': 23.679, 'start_gpu_memory': 8.662}
 20%|█▉        | 4998/25000 [05:09<18:50, 17.69it/s]GPU Stats: {'gpu': 'NVIDIA RTX A5000', 'max_memory': 23.679, 'start_gpu_memory': 8.662}
 24%|██▍       | 5998/25000 [06:11<19:01, 16.65it/s]GPU Stats: {'gpu': 'NVIDIA RTX A5000', 'max_memory': 23.679, 'start_gpu_memory': 8.662}
 28%|██▊       | 6998/25000 [07:12<18:18, 16.39it/s]GPU Stats: {'gpu': 'NVIDIA RTX A5000', 'max_memory': 23.679, 'start_gpu_memory': 8.662}
 32%|███▏      | 79