<a href="https://colab.research.google.com/github/Daniele-Cangi/AIpowerCoin-Public-Documentation-Collaboration-Portal/blob/main/NEXUS_TRANSFER_64_ANCHORS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 📝 NEXUS_TRANSFER_64_ANCHORS.ipynb
# -------------------------------------------------------------
# 🚀 NEXUS Transfer Learning - 64 Anchors Edition
# ✅ Checkpoint ogni 2000 steps
# ✅ 64 NSA anchors (vs 32 original)
# ✅ Analisi anchor specialization post-training
# -------------------------------------------------------------
# STEP 0: MOUNT GOOGLE DRIVE & SETUP CHECKPOINT DIR
from google.colab import drive
import os

drive.mount('/content/drive')
CHECKPOINT_DIR = '/content/drive/MyDrive/nexus_checkpoints_64anchors'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"✅ Google Drive montato! -> {CHECKPOINT_DIR}")
print("⚠️ I checkpoint sopravvivono ai disconnect di Colab.")

Mounted at /content/drive
✅ Google Drive montato! -> /content/drive/MyDrive/nexus_checkpoints_64anchors
⚠️ I checkpoint sopravvivono ai disconnect di Colab.


In [2]:
# STEP 1: SETUP AMBIENTE (Re-eseguibile)
!pip -q install torch transformers datasets tqdm sentencepiece accelerate
import torch, platform
print(f"Torch: {torch.__version__} | Python: {platform.python_version()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1024**3:.1f} GB")
else:
    print("⚠️ No CUDA detected")

Torch: 2.8.0+cu126 | Python: 3.12.11
GPU: NVIDIA A100-SXM4-40GB | VRAM: 39.6 GB


In [3]:
# STEP 2: DOWNLOAD MICROSOFT ALLNLI DATASET
from datasets import load_dataset
import json
from tqdm import tqdm

print("📥 Downloading Microsoft AllNLI dataset...")

# Download AllNLI (combines SNLI + MultiNLI)
try:
    # Load SNLI dataset
    snli = load_dataset("snli", split="train")
    print(f"✅ SNLI loaded: {len(snli):,} examples")

    # Load MultiNLI dataset
    mnli = load_dataset("multi_nli", split="train")
    print(f"✅ MultiNLI loaded: {len(mnli):,} examples")

    # Filter out examples with label -1 (unlabeled) and combine
    all_data = []

    # Process SNLI
    for item in tqdm(snli, desc="Processing SNLI"):
        if item['label'] != -1:  # Skip unlabeled
            all_data.append({
                'sentence1': item['premise'],
                'sentence2': item['hypothesis'],
                'label': ['entailment', 'neutral', 'contradiction'][item['label']]
            })

    # Process MultiNLI
    for item in tqdm(mnli, desc="Processing MultiNLI"):
        if item['label'] != -1:  # Skip unlabeled
            all_data.append({
                'sentence1': item['premise'],
                'sentence2': item['hypothesis'],
                'label': ['entailment', 'neutral', 'contradiction'][item['label']]
            })

    # Save to JSONL for PairDataset
    with open('/content/allnli_pairs.jsonl', 'w') as f:
        for item in all_data:
            f.write(json.dumps(item) + '\n')

    print(f"✅ AllNLI dataset ready: {len(all_data):,} examples")
    print(f"📁 Saved to: /content/allnli_pairs.jsonl")

except Exception as e:
    print(f"❌ Error downloading dataset: {e}")
    print("📝 Creating minimal demo dataset for testing...")

    # Fallback demo data with correct format
    demo_data = [
        {"sentence1": "A person on a horse jumps over a broken down airplane.", "sentence2": "A person is at a diner, ordering an omelette.", "label": "contradiction"},
        {"sentence1": "Children smiling and waving at camera", "sentence2": "They are smiling at a camera", "label": "entailment"},
        {"sentence1": "A black race car starts up in front of a crowd of people.", "sentence2": "A man is driving down a lonely road.", "label": "contradiction"},
        {"sentence1": "Two women are embracing while holding to go packages.", "sentence2": "Two women are holding packages.", "label": "entailment"},
        {"sentence1": "A soccer game with multiple males playing.", "sentence2": "Some men are playing a sport.", "label": "entailment"}
    ] * 1000  # Repeat for testing

    with open('/content/allnli_pairs.jsonl', 'w') as f:
        for item in demo_data:
            f.write(json.dumps(item) + '\n')

    print(f"✅ Demo dataset created: {len(demo_data):,} examples")

📥 Downloading Microsoft AllNLI dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/412k [00:00<?, ?B/s]

plain_text/validation-00000-of-00001.par(…):   0%|          | 0.00/413k [00:00<?, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/19.6M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/550152 [00:00<?, ? examples/s]

✅ SNLI loaded: 550,152 examples


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

data/validation_matched-00000-of-00001.p(…):   0%|          | 0.00/4.94M [00:00<?, ?B/s]

data/validation_mismatched-00000-of-0000(…):   0%|          | 0.00/5.10M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating validation_matched split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating validation_mismatched split:   0%|          | 0/9832 [00:00<?, ? examples/s]

✅ MultiNLI loaded: 392,702 examples


Processing SNLI: 100%|██████████| 550152/550152 [00:17<00:00, 30688.44it/s]
Processing MultiNLI: 100%|██████████| 392702/392702 [00:35<00:00, 11215.43it/s]


✅ AllNLI dataset ready: 942,069 examples
📁 Saved to: /content/allnli_pairs.jsonl


In [4]:
# STEP 3: MODEL + LOSS DEFINITIONS
import torch, torch.nn as nn, torch.nn.functional as F, numpy as np, random
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
import math, json

class NexusTransfer(nn.Module):
    def __init__(self, base_model_name='microsoft/mpnet-base', num_topics=3, hash_dim=256, nsa_anchors=64):
        super().__init__()
        self.base = AutoModel.from_pretrained(base_model_name)
        d = self.base.config.hidden_size
        self.topic_head = nn.Linear(d, num_topics)
        self.hash_proj = nn.Linear(d, hash_dim)
        self.anchors = nn.Parameter(torch.randn(nsa_anchors, d))
        print(f"✅ Base {base_model_name} | Params {sum(p.numel() for p in self.parameters()):,}")
    def forward(self, input_ids, attention_mask=None):
        out = self.base(input_ids=input_ids, attention_mask=attention_mask)
        h = out.last_hidden_state
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
            h = (h * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        else:
            h = h.mean(1)
        emb = F.normalize(h, dim=-1)
        topic_logits = self.topic_head(emb)
        hash_emb = torch.tanh(self.hash_proj(emb))
        nsa_scores = emb @ self.anchors.T
        return emb, topic_logits, hash_emb, nsa_scores

class PairDataset(Dataset):
    def __init__(self, path):
        with open(path,'r',encoding='utf-8') as f:
            self.data = [json.loads(l) for l in f]
        # Map labels to integers for topic classification
        self.label_map = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
    def __len__(self): return len(self.data)
    def __getitem__(self,i):
        d = self.data[i]
        return d['sentence1'], d['sentence2'], self.label_map[d['label']]

def info_nce(a,b,t=0.07):
    sim = a @ b.T / t
    labels = torch.arange(sim.size(0), device=sim.device)
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.T, labels)) / 2

def matryoshka_loss(a,b,dims,t=0.07):
    return sum(info_nce(F.normalize(a[:,:d],dim=-1), F.normalize(b[:,:d],dim=-1), t) for d in dims)/len(dims)

def hash_loss(h):
    return ((h - h.sign())**2).mean()

def spec_loss(scores_a, scores_b):
    return 1.0 - (F.normalize(scores_a,dim=-1)*F.normalize(scores_b,dim=-1)).sum(-1).mean()

In [5]:
# STEP 4: INIT MODEL + DATA + OPTIMIZER
import torch, time
from torch.cuda.amp import GradScaler, autocast

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

tokenizer = AutoTokenizer.from_pretrained('microsoft/mpnet-base')
model = NexusTransfer(nsa_anchors=64).to(device)
try:
    model = torch.compile(model, mode='max-autotune')
    print('torch.compile enabled')
except Exception as e:
    print('compile skipped:', e)

ds = PairDataset('/content/allnli_pairs.jsonl')
print(f'Dataset size: {len(ds):,}')

BATCH=64; MAX_LEN=128; DIMS=[768,512,384,256,128]; TEMP=0.07

def collate(batch):
    A,B,T = zip(*batch)
    enc_a = tokenizer(list(A), padding=True, truncation=True, max_length=MAX_LEN, return_tensors='pt')
    enc_b = tokenizer(list(B), padding=True, truncation=True, max_length=MAX_LEN, return_tensors='pt')
    return enc_a['input_ids'], enc_a['attention_mask'], enc_b['input_ids'], enc_b['attention_mask'], torch.tensor(T)

dl = DataLoader(ds, batch_size=BATCH, shuffle=True, drop_last=True, collate_fn=collate)

base_params = list(model._orig_mod.base.parameters()) if hasattr(model,'_orig_mod') else list(model.base.parameters())
heads = [p for n,p in model.named_parameters() if 'base.' not in n]
opt = torch.optim.AdamW([
    {'params': base_params, 'lr': 2e-5},
    {'params': heads, 'lr': 1e-4}
], weight_decay=0.01, fused=True)
scaler = GradScaler()

start_step = 0
latest_ptr = f"{CHECKPOINT_DIR}/latest_checkpoint.txt"
if os.path.exists(latest_ptr):
    try:
        with open(latest_ptr) as f: ckpt_path = f.read().strip()
        ckpt = torch.load(ckpt_path)
        state = ckpt['model']
        if any(k.startswith('_orig_mod.') for k in state):
            state = {k.replace('_orig_mod.',''):v for k,v in state.items()}
            (model._orig_mod if hasattr(model,'_orig_mod') else model).load_state_dict(state)
        else:
            (model._orig_mod if hasattr(model,'_orig_mod') else model).load_state_dict(state)
        opt.load_state_dict(ckpt['opt'])
        start_step = ckpt['step']
        print('Resumed from', ckpt_path, 'step', start_step)
    except Exception as e:
        print('Resume failed:', e)

Device: cuda


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/493 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/532M [00:00<?, ?B/s]

Some weights of MPNetModel were not initialized from the model checkpoint at microsoft/mpnet-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Base microsoft/mpnet-base | Params 109,734,787
torch.compile enabled
Dataset size: 942,069


  scaler = GradScaler()


Resumed from /content/drive/MyDrive/nexus_checkpoints_64anchors/ckpt_30000.pt step 30000


In [6]:
# STEP 5: TRAINING LOOP (checkpoint ogni 2000 step)
TOTAL_STEPS = 30000
CKPT_INTERVAL = 2000
step = start_step
model.train(); opt.zero_grad(set_to_none=True)
start_time = time.time()
for epoch in range(999999):
    for a_ids, a_mask, b_ids, b_mask, topics in dl:
        a_ids=a_ids.to(device); a_mask=a_mask.to(device); b_ids=b_ids.to(device); b_mask=b_mask.to(device); topics=topics.to(device)
        with autocast():
            emb_a, logit_a, hash_a, spec_a = model(a_ids, a_mask)
            emb_b, logit_b, hash_b, spec_b = model(b_ids, b_mask)
            L_main = matryoshka_loss(emb_a, emb_b, DIMS, t=TEMP)
            L_topic = 0.03 * F.cross_entropy(logit_a, topics)
            L_hash = 0.01 * hash_loss(hash_a)
            L_spec = 0.03 * spec_loss(spec_a, spec_b)
            loss = L_main + L_topic + L_hash + L_spec
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
        scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True)
        step += 1
        if step % 100 == 0:
            elapsed = (time.time()-start_time)/3600
            speed = (step-start_step)/elapsed if elapsed>0 else 0
            print(f"Step {step:,} | Loss {loss.item():.4f} | M {L_main.item():.3f} T {L_topic.item():.3f} H {L_hash.item():.3f} S {L_spec.item():.3f} | {speed:.0f} st/h")
        if step % CKPT_INTERVAL == 0:
            raw = model._orig_mod if hasattr(model,'_orig_mod') else model
            ckpt_path = f"{CHECKPOINT_DIR}/ckpt_{step}.pt"
            torch.save({'step': step, 'model': raw.state_dict(), 'opt': opt.state_dict()}, ckpt_path)
            with open(f"{CHECKPOINT_DIR}/latest_checkpoint.txt","w") as f: f.write(ckpt_path)
            print('💾 Saved', ckpt_path)
        if step >= TOTAL_STEPS:
            break
    if step >= TOTAL_STEPS:
        break
print('✅ Training complete. Total steps:', step)

  with autocast():
AUTOTUNE addmm(4352x3072, 4352x768, 768x3072)
strides: [0, 1], [768, 1], [1, 768]
dtypes: torch.float16, torch.float16, torch.float16
  triton_mm_130 0.1085 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_129 0.1167 ms 93.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_122 0.1208 ms 89.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_126 0.1208 ms 89.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_124 0.1239 ms 87.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K

✅ Training complete. Total steps: 30001


In [7]:
# STEP 6: EVALUATION + SIMILARITY TESTS
import numpy as np, os, torch
latest_ptr = f"{CHECKPOINT_DIR}/latest_checkpoint.txt"
if not os.path.exists(latest_ptr):
    raise FileNotFoundError('No checkpoint to evaluate')
with open(latest_ptr) as f: ckpt_path = f.read().strip()
ckpt = torch.load(ckpt_path, map_location='cpu')
raw = NexusTransfer(nsa_anchors=64).to(device)
raw.load_state_dict(ckpt['model'])
raw.eval()

def embed(texts, batch_size=32):
    all_vecs = []
    for i in range(0,len(texts), batch_size):
        batch = texts[i:i+batch_size]
        enc = tokenizer(batch, padding=True, truncation=True, max_length=MAX_LEN, return_tensors='pt').to(device)
        with torch.no_grad():
            emb, *_ = raw(enc['input_ids'], enc.get('attention_mask'))
        all_vecs.append(emb.cpu())
    return torch.cat(all_vecs).numpy()

tests = [
 ("Bitcoin reaches new all-time high","BTC price surges",True),
 ("AI is advancing rapidly","Machine learning progresses",True),
 ("The cat sleeps on the sofa","A feline rests",True),
 ("Rome is the capital of Italy","Bitcoin price rises",False),
 ("Pizza is Italian food","Neural networks for AI",False)
]
pos, neg = [], []
for a,b,is_pos in tests:
    va, vb = embed([a])[0], embed([b])[0]
    sim = float(np.dot(va,vb))
    (pos if is_pos else neg).append(sim)
    badge = '✅' if (sim>0.65 and is_pos) or (sim<0.35 and not is_pos) else '⚠️'
    print(f"{badge} {a[:32]:32} <> {b[:32]:32} | {sim:.3f}")
print('\nAverages -> pos:', np.mean(pos), 'neg:', np.mean(neg), 'separation:', np.mean(pos)-np.mean(neg))

Some weights of MPNetModel were not initialized from the model checkpoint at microsoft/mpnet-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Base microsoft/mpnet-base | Params 109,734,787
✅ Bitcoin reaches new all-time hig <> BTC price surges                 | 0.720
⚠️ AI is advancing rapidly          <> Machine learning progresses      | 0.583
✅ The cat sleeps on the sofa       <> A feline rests                   | 0.778
✅ Rome is the capital of Italy     <> Bitcoin price rises              | 0.121
✅ Pizza is Italian food            <> Neural networks for AI           | 0.151

Averages -> pos: 0.6933179299036661 neg: 0.13598528504371643 separation: 0.5573326448599497


In [8]:
# STEP 7: ANCHOR SPECIALIZATION + EXPORT
from tqdm import tqdm
import csv
raw.eval()

REPORT_TXT = f"{CHECKPOINT_DIR}/anchor_analysis.txt"
REPORT_CSV = f"{CHECKPOINT_DIR}/anchor_top.csv"

@torch.no_grad()
def analyze_anchors(sample_size=800):
    import random
    sample_size = min(sample_size, len(ds))
    idxs = random.sample(range(len(ds)), sample_size)
    activations = {i: [] for i in range(raw.anchors.shape[0])}
    for i in tqdm(idxs, desc='Anchors'):
        a,b,_ = ds[i]
        enc = tokenizer([a], padding=True, truncation=True, max_length=MAX_LEN, return_tensors='pt').to(device)
        emb, _, _, scores = raw(enc['input_ids'], enc.get('attention_mask'))
        scores = scores[0].cpu().tolist()
        for k,s in enumerate(scores):
            if len(activations[k]) < 30:  # retain limited examples per anchor
                activations[k].append((s,a[:120]))
    with open(REPORT_TXT,'w',encoding='utf-8') as f_txt, open(REPORT_CSV,'w',newline='',encoding='utf-8') as f_csv:
        writer = csv.writer(f_csv); writer.writerow(['anchor','score','text'])
        for anchor, items in activations.items():
            items.sort(key=lambda x: x[0], reverse=True)
            f_txt.write(f"\n==== ANCHOR {anchor} ===\n")
            for s,t in items[:10]:
                f_txt.write(f"{s:.3f} | {t}\n"); writer.writerow([anchor, f"{s:.3f}", t])
    print('Saved anchor reports:', REPORT_TXT, REPORT_CSV)

analyze_anchors()

EXPORT_PATH = f"{CHECKPOINT_DIR}/nexus_64anchors_final.pt"
raw_to_save = raw._orig_mod if hasattr(raw,'_orig_mod') else raw
torch.save({'model_state_dict': raw_to_save.state_dict(), 'config': {'base_model':'microsoft/mpnet-base','nsa_anchors':64,'hash_dim':256,'embedding_dim':768,'matryoshka_dims':DIMS}, 'training_steps': step}, EXPORT_PATH)
print('Final model exported ->', EXPORT_PATH)

Anchors: 100%|██████████| 800/800 [00:08<00:00, 94.42it/s]


Saved anchor reports: /content/drive/MyDrive/nexus_checkpoints_64anchors/anchor_analysis.txt /content/drive/MyDrive/nexus_checkpoints_64anchors/anchor_top.csv
Final model exported -> /content/drive/MyDrive/nexus_checkpoints_64anchors/nexus_64anchors_final.pt


# Task
Analyze the behavior and reactions of the model by performing a more in-depth test using a larger sample of the dataset.

## Generate embeddings for a larger dataset sample

### Subtask:
Select a larger random sample from the dataset and generate embeddings and anchor scores for each text in the sample.


**Reasoning**:
Select a larger sample of the dataset, generate embeddings and anchor scores for both sentences in each pair from the sample, and store them for analysis.



**Reasoning**:
Calculate the mean and standard deviation of anchor scores for both sentence A and sentence B samples across all anchors, and then visualize the distributions.



## Identify top activating examples per anchor

### Subtask:
For each anchor, identify the top examples that cause the highest activation scores.


**Reasoning**:
Initialize the dictionary to store top examples for each anchor and then iterate through the sampled data to populate it.



## Qualitative analysis of top examples

### Subtask:
Manually review the top examples for a subset of anchors to understand what kind of text content or semantic features each anchor is specializing in.


**Reasoning**:
Manually review the top examples for a subset of anchors to understand their specialization. I will select a few anchors based on the previous analysis of mean and standard deviation of anchor scores and print their top examples for manual inspection.



## Visualize anchor relationships

### Subtask:
If possible, visualize the relationships between anchors based on their co-activation patterns or the similarity of the embeddings they are most sensitive to. This could involve techniques like t-SNE or UMAP on anchor vectors or the embeddings of highly activating texts.


**Reasoning**:
Calculate the pairwise cosine similarity between the anchor vectors and then use t-SNE to reduce the dimensionality for visualization.



## Summarize findings

### Subtask:
Summarize the findings from the quantitative analysis of anchor activation distribution and the qualitative analysis of top examples to provide insights into the model's behavior and anchor specialization.


**Reasoning**:
Summarize the findings from the quantitative and qualitative analyses and the t-SNE visualization to provide insights into the model's behavior and anchor specialization.



## Summary:

### Data Analysis Key Findings

*   The mean anchor scores for both sentence A and sentence B samples varied significantly across anchors, ranging from approximately -2.2 to 2.4 for sentence A and -2.3 to 2.5 for sentence B.
*   The standard deviation of anchor scores for individual anchors across the sampled texts was generally lower than the mean scores, ranging from approximately 0.2 to 0.98.
*   Qualitative analysis of top examples suggested that Anchor 30 might be sensitive to political or controversial content, Anchor 1 to more general or neutral content, and Anchor 53 to informal/conversational language or technical vocabulary.
*   The t-SNE visualization of anchor vectors suggests potential relationships or clustering among anchors based on their representation in the model's learned space.

### Insights or Next Steps

*   The combination of quantitative and qualitative analysis provides initial hypotheses about anchor functions, which could be further explored through more extensive manual review or automated text analysis of top examples.
*   Investigating the clusters observed in the t-SNE visualization could reveal groups of anchors that collectively represent specific types of semantic information.


# Task
Suggest ways to improve the model's performance based on the training results and analysis.

## Explore additional training data

### Subtask:
Identify and potentially integrate other relevant datasets for training, such as other NLI datasets (e.g., ANLI) or datasets focused on specific domains or tasks relevant to the intended use of the model.


**Reasoning**:
Outline a plan for integrating other relevant datasets and discuss potential challenges.



## Hyperparameter tuning

### Subtask:
Experiment with different hyperparameters for training, such as learning rates, batch sizes, weight decay, temperature for InfoNCE loss, or the weights of different loss components (main, topic, hash, spec).


**Reasoning**:
Modify the `STEP 4` code cell to include easily adjustable hyperparameters and run a training loop with these parameters.



**Reasoning**:
Run evaluation with the current hyperparameters to assess performance and then modify the hyperparameters in the `HYPERPARAMS` dictionary and re-run the training and evaluation cells to experiment with different settings as requested by the subtask.



## Architecture modifications

### Subtask:
Consider minor modifications to the model architecture, such as adding more layers to the projection heads, trying different activation functions, or experimenting with different base models if appropriate.


**Reasoning**:
Review the current model architecture and propose a minor modification to the projection heads by adding an additional linear layer with ReLU activation.

