In [1]:
### All imports ###
import os
import sys
paths_to_add = ["..", "../.."]
for path in paths_to_add:
    sys_path = os.path.relpath(path)
    if sys_path not in sys.path:  # Check to avoid duplicates
        sys.path.append(sys_path)

import argparse
import itertools
from functools import partial
import copy
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import pickle

import torch.nn.functional as F

import time as TIME

from _src.datasets import prepare_data
from _src.models import RoPETransformer, RoPEFlashAttention

from torch.utils.data.distributed import DistributedSampler

import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams.update({"font.size": 20})


### Custom class to use pickle files on cpu ###
import io
class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else:
            return super().find_class(module, name)


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

### Helper functions to create task lists ### 
def generate_all_unique_sublists(args):
    all_combinations = list(itertools.product(range(1, args.base), repeat=args.n_var))
    if len(all_combinations) < args.n_tasks:
        raise ValueError("Not enough unique combinations available.")
    selected_combinations = random.sample(all_combinations, args.n_tasks)
    return [list(combination) for combination in selected_combinations]


def get_ood_lists(Ws, args):
    all_combinations = list(itertools.product(range(1, args.base), repeat=args.n_var))
    Ws = set(tuple(W) for W in Ws)
    return list(set(all_combinations) - Ws)


def attach_tasks_with_shared_components(Ws_unique, args):
    """Currently only work"""
    all_possibilities = itertools.product(range(args.p), repeat=len(Ws_unique[0]))
    Ws = copy.deepcopy(Ws_unique)
    for possible_W in all_possibilities:
        for W in Ws_unique:
            for i in range(len(possible_W)):
                if W[i] == possible_W[i] and possible_W not in Ws:
                    Ws.append(list(possible_W))
    return Ws


def parallelogram_tasks_with_shared_components(Ws_unique, args):
    """Currently only work for 2 variables"""
    def generate_lists(a, b, p):
        # Generate unique list combinations given the constraints
        while True:
            x = random.choice([i for i in range(p) if i != a and i != b])
            y = random.choice([i for i in range(p) if i != a and i != b and i != x])

            list1 = [a, y]
            list2 = [x, b]
            list3 = [x, y]

            if list1 not in Ws and list2 not in Ws and list3 not in Ws:
                return [list1, list2, list3]

    Ws = copy.deepcopy(Ws_unique)
    while len(Ws) < 4 * len(Ws_unique):
        for W in Ws_unique:
            if len(Ws) >= 4 * len(Ws_unique):
                break  # Break early if the target size is already reached
            new_Ws = generate_lists(W[0], W[1], args.p)
            Ws.extend(new_Ws)
    
    return Ws

### Parser for all required settings ###
parser = argparse.ArgumentParser(description="Transformer Grokking")
parser.add_argument("--model_name", default="rope_decoder", type=str, help="Encoder or Decoder only Transformers")
parser.add_argument("--one_shot", default=False, type=str2bool, help="One shot or CoT")
parser.add_argument("--device", default="cpu", type=str, help="device")
parser.add_argument("--dtype", default="float32", type=str, help="dtype")
parser.add_argument("--mixed_precision", default=False, type=str2bool, help="Automatic Mixed Precision")
parser.add_argument("--seed", default=1, type=int, help="random seed")
parser.add_argument("--ddp", default=False, type=str2bool, help="DDP or not")
parser.add_argument("--world_size", default=1, type=int, help="World Size")

# Model Settings
parser.add_argument("--n_layer", default=4, type=int, help="Number of Transformer Blocks")
parser.add_argument("--dp", default=0.0, type=float, help="Dropout Probability")
parser.add_argument("--if_ln", default=True, type=str2bool, help="If use LayerNorm or Not")
parser.add_argument("--n_embd", default=512, type=int, help="Embedding Dimension")
parser.add_argument("--n_head", default=4, type=int, help="Number of Heads")
parser.add_argument("--block_size", default=512, type=int, help='maximum length')
parser.add_argument("--act_name", default="relu", type=str, help="activation: relu, gelu, swiglu")
parser.add_argument("--widen_factor", default=4, type=int, help="MLP widening")
parser.add_argument("--mu", default=1.0, type=float, help="Skip connection strength")
parser.add_argument("--weight_tying", default=False, type=str2bool, help="If use weight tying")
parser.add_argument("--dont_decay_embd", default=True, type=str2bool, help="If use weight tying")

# Data
parser.add_argument("--n_tasks", default=64, type=int, help="number of independent tasks")
parser.add_argument("--parallelogram", default=True, type=str2bool, help="Perform parallelogram construction on task vectors or not")
parser.add_argument("--n_var", default=2, type=int, help="number of variables, i.e. dimension of the problem")
parser.add_argument("--data_seed", default=0, type=int, help="random seed for generating datasets")
parser.add_argument("--data_pct", default=80.0, type=float, help="Data Percentage")
parser.add_argument("--task_pct", default=50.0, type=float, help="Task Percentage")
parser.add_argument("--p", default=29, type=int, help="Modulo p")
parser.add_argument("--base", default=29, type=int, help="Represent Numbers in base")
parser.add_argument("--n_point_per_row", default=32, type=int, help="Number of data points per row")
parser.add_argument("--n_point_per_row_gen", default=1, type=int, help="k-shot")
parser.add_argument("--ctx_masked", default=0, type=int, help="Number of first i data points to mask")
parser.add_argument("--encrypted", default=True, type=str2bool, help="Write the task vectors in data or not.")
parser.add_argument("--pos_hint", default=False, type=str2bool, help="Add positional hint or not")
parser.add_argument("--reverse_target", default=False, type=str2bool, help="Reverse the digits order of targets or not")
parser.add_argument("--show_mod", default=False, type=str2bool, help="Add mod p to token or not")
parser.add_argument("--show_seos", default=False, type=str2bool, help="USe SOS and EOS or not")
parser.add_argument("--split_tasks", default=False, type=str2bool, help="Train/Test set have different task vectors or not.")
parser.add_argument("--split_data", default=True, type=str2bool, help="Train/Test set have different datapoints or not.")

# Optimization
parser.add_argument("--optim", default="adamw", type=str, help="Optimizer: adamw or sgd")
parser.add_argument("--s", default=0.0, type=float, help="s=0 for SP, 1 for muP like attention. Use 0.0 only for now.")
parser.add_argument("--bs", default=1024, type=int, help="Batchsize")
parser.add_argument("--eval_bs", default=1024, type=int, help="Batchsize for Evaluation")
parser.add_argument("--lr", default=1.5e-4, type=float, help="Learning Rate")
parser.add_argument("--n_cycles", default=1, type=int, help="Cycles of Scheduler")
parser.add_argument("--wd", default=5.0, type=float, help="Weight Decay")
parser.add_argument("--beta1", default=0.9, type=float, help="Beta 1 for AdamW")
parser.add_argument("--beta2", default=0.98, type=float, help="Beta 2 for AdamW")
parser.add_argument("--eps", default=1e-8, type=float, help="Eps for AdamW")
parser.add_argument("--momentum", default=0.9, type=float, help="Momentum for SGD")
parser.add_argument("--ckpt_step", default=160000, type=int, help="Training Epochs")
parser.add_argument("--steps", default=200000, type=int, help="Training Epochs")
parser.add_argument("--warmup_steps", default=10000, type=int, help="Warmup Epochs")
parser.add_argument("--lr_decay", default='cosine', type=str, help="If Use Scheduler")
parser.add_argument("--steps_per_record", default=1000, type=int, help="Save Results")
parser.add_argument("--reshuffle_step", default=1, type=int, help="Save Results")
parser.add_argument("--no_of_steps", default=1, type = int, help='Number of examples')


# Inference
parser.add_argument("--train_set", default=False, type=str2bool, help="Use training inputs or val inputs")
parser.add_argument("--n_measure", default=3, type=int, help="How many batches to average.")
parser.add_argument("--savefig", default=True, type=str2bool, help="Save Figure")
parser.add_argument("--no_of_dis_el", default = 1, type = int, help='Number of distinct elements')
parser.add_argument("--ood_tasks",default=False,type=str2bool,help='Whether to use ID (False) or OOD (True) tasks')
#parser.add_argument("--no_o")


args, unknown = parser.parse_known_args()
assert args.show_seos == False

device = torch.device(args.device)
if args.mixed_precision is True:
    assert args.dtype in ['float16', 'bfloat16']
    assert 'cuda' in args.device
    if args.dtype == 'float16':
        args.dtype = torch.float16
    else:
        args.dtype = torch.bfloat16
else:
    torch.set_float32_matmul_precision('high')
    args.dtype = torch.float32

### Set random seeds manually ###
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

Ws = generate_all_unique_sublists(args)
Ws_og = Ws
print(len(Ws_og))
if args.parallelogram is True:
    Ws = parallelogram_tasks_with_shared_components(Ws, args)
check = list(set(tuple(W) for W in Ws))

args.pre_train_n_tasks = len(Ws)

Ws_ood = get_ood_lists(Ws, args)

random.shuffle(Ws_ood)

# Choose 16 tasks, half are OOD tasks, half are ID.
no_of_tasks_tc = 16
Ws_i = [Ws[i] for i in range(no_of_tasks_tc//2) ]
Ws_o = [Ws_ood[i] for i in range(no_of_tasks_tc//2-1)]

Ws = Ws_i + [(1,1)] +  Ws_o

args.n_tasks = len(Ws)
# print(Ws, args.n_tasks)

if args.n_tasks == 1:
    assert args.split_tasks == False

### Defining datasets, data, data-loader. ###
    
train_set, val_set, tokenizer = prepare_data(args, Ws)
# print(train_set[:2, :])
# exit()
original_n_train_row = train_set.size(0)
original_n_val_row = val_set.size(0)
args.vocab_size = tokenizer.__len__()
args.max_digits = tokenizer.max_digits
args.max_digits = 2 * args.max_digits if args.pos_hint is True else args.max_digits
args.dim = args.max_digits * (len(Ws[0]) + 1)


"""Make copies of data"""
if args.split_tasks is True:
    task_rows = round(args.n_tasks * (args.task_pct / 100.0))
    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)
    val_set = val_set.view((args.n_tasks - task_rows), len(val_set) // task_rows, -1)
else:
    task_rows = args.n_tasks
    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)
    val_set = val_set.view(args.n_tasks, len(val_set) // task_rows, -1)

class CustomDataset(Dataset):
    def __init__(self, dataset, bs, args):
        self.dataset = dataset.transpose(0, 1)
        self.n_data, self.n_task, self.dim = self.dataset.shape
        self.bs = bs
        self.args = args
        
    def __len__(self):
        return self.n_data * self.args.n_point_per_row * 5 * self.args.n_measure # To ensure we don't have to restart dataloader.

    def __getitem__(self, idx):
        step_x = self.dataset[idx % self.n_data] # (n_tasks, dim)
        return step_x


def custom_collate_fn(batch, bs, args):
    inputs = torch.stack(batch, dim=1)
    idx = torch.randperm(inputs.size(1))
    inputs = inputs[torch.arange(args.n_tasks)[:, None], idx[None, :]]
    targets = inputs.clone() # (n_tasks, bs * ctx_length, dim)
    targets[:, :, :-args.max_digits] = -100 # Mask the unpredictable part
    return inputs.view(bs, -1), targets.view(bs, -1)  # (bs, ctx_length * dim)


def get_dataloader(dataset, bs, args):
    custom_dataset = CustomDataset(dataset, bs, args)
    collate_fn = partial(custom_collate_fn, bs=bs, args=args) # collate_fn should only take one input
    if args.ddp is False:
        g = torch.Generator()
        g.manual_seed(args.seed + 38493483)
        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks), num_workers=0, collate_fn=collate_fn, drop_last=True, generator=g) 
    else:
        sampler = DistributedSampler(custom_dataset, shuffle=False)
        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks) // args.world_size, drop_last=True, num_workers=0, collate_fn=collate_fn, sampler=sampler) 
    return dataloader

args.eval_bs = len(Ws)
train_iter = get_dataloader(train_set, args.eval_bs, args)
val_iter = get_dataloader(val_set, args.eval_bs, args)

if args.train_set is True:
    iter_to_use = get_dataloader(train_set, args.eval_bs, args)
else:
    iter_to_use = get_dataloader(val_set, args.eval_bs, args)

64


In [None]:
### Checkpoint name ###
ckpt_path = f'../ckpts/d{args.n_layer}_h{args.n_head}_embd{args.n_embd}/noembd{args.dont_decay_embd}_parale{args.parallelogram}_{args.model_name}_p{args.p}_base{args.base}_row{args.n_point_per_row}_ntask{args.pre_train_n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_dfrac{args.data_pct:.1f}_{args.act_name}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}_lctx{args.block_size}_I{args.seed}_dI{args.data_seed}_{args.optim}_bs{args.bs}_t{args.ckpt_step:d}_T{args.steps:d}_Tw{args.warmup_steps:d}_Trshf{args.reshuffle_step}_lr{args.lr:0.2e}_wd{args.wd:.2e}.pth'

model = RoPETransformer(RoPEFlashAttention, args).to(device=device)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=False)
model.to(device)

# Label corruption (at random locations) 

In [2]:
######## LABEL CORRUPTION AT RANDOM LOCATIONS ########


#### Define data-loader for random label corruption ####

class CustomDatasetCorruption(Dataset):
    def __init__(self, dataset, bs, args):
        self.dataset = dataset.transpose(0, 1)
        self.n_data, self.n_task, self.dim = self.dataset.shape
        self.bs = bs
        self.args = args
        
    def __len__(self):
        return self.n_data * self.args.n_point_per_row * 5 * self.args.n_measure # To ensure we don't have to restart dataloader.

    def __getitem__(self, idx):
        step_x = self.dataset[idx % self.n_data] # (n_tasks, dim)
        return step_x


def custom_collate_fn_corrupt(batch, bs, args):
    inputs = torch.stack(batch, dim=1)
    idx = torch.randperm(inputs.size(1))
    inputs = inputs[torch.arange(args.n_tasks)[:, None], idx[None, :]]
    targets = copy.deepcopy(inputs) # (n_tasks, bs * ctx_length, dim)
    targets[:, :, :-args.max_digits] = -100 # Mask the unpredictable part
    
    if args.no_of_corr > 0:
        to_corrupt = torch.randperm(args.n_point_per_row)[:args.no_of_corr]
        inputs[:,to_corrupt,args.to_corrupt] = (inputs[:,to_corrupt, args.to_corrupt]+torch.randint(1,args.p,(inputs.size(0), args.no_of_corr)))%args.p
    
    return inputs.view(bs, -1), targets.view(bs, -1)  # (bs, ctx_length * dim)


def get_dataloader_corrupt(dataset, bs, args):
    custom_dataset = CustomDatasetCorruption(dataset, bs, args)
    collate_fn = partial(custom_collate_fn_corrupt, bs=bs, args=args) # collate_fn should only take one input
    if args.ddp is False:
        g = torch.Generator()
        g.manual_seed(args.seed + 38493483)
        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks), num_workers=0, collate_fn=collate_fn, drop_last=True, generator=g) 
    else:
        sampler = DistributedSampler(custom_dataset, shuffle=False)
        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks) // args.world_size, drop_last=True, num_workers=0, collate_fn=collate_fn, sampler=sampler) 
    return dataloader

Ws = [(1,1),(1,2)]
args.n_tasks = len(Ws)
train_set, val_set, tokenizer = prepare_data(args, Ws)
# print(train_set[:2, :])
# exit()
original_n_train_row = train_set.size(0)
original_n_val_row = val_set.size(0)
args.vocab_size = tokenizer.__len__()
args.max_digits = tokenizer.max_digits
args.max_digits = 2 * args.max_digits if args.pos_hint is True else args.max_digits
args.dim = args.max_digits * (len(Ws[0]) + 1)


"""Make copies of data"""
if args.split_tasks is True:
    task_rows = round(args.n_tasks * (args.task_pct / 100.0))
    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)
    val_set = val_set.view((args.n_tasks - task_rows), len(val_set) // task_rows, -1)
else:
    task_rows = args.n_tasks
    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)
    val_set = val_set.view(args.n_tasks, len(val_set) // task_rows, -1)

args.eval_bs = len(Ws)

if args.train_set is True:
    corrupt_iter_to_use = get_dataloader_corrupt(train_set, args.eval_bs, args)
else:
    corrupt_iter_to_use = get_dataloader_corrupt(val_set, args.eval_bs, args)

if args.train_set is True:
    iter_to_use = get_dataloader(train_set, args.eval_bs, args)
else:
    iter_to_use = get_dataloader(val_set, args.eval_bs, args)
    
### Ws[no_of_tasks_tc//2] is (1,1)
# print(Ws[no_of_tasks_tc//2])

### TEST CODE TO SEE IF DATALOADER WORKS ###
args.n_measure = 1
args.n_point_per_row = 4
args.no_of_corr = 2
args.to_corrupt = 2 ### which one to corrupt, 0 -> x, 1 -> y, 2 -> f.

with torch.inference_mode():
    print(len(iter_to_use))
    for t, (x,y) in enumerate(iter_to_use):
        print("The seq is ", x[0,:])
        if t >= 0:
            break
    for t, (x,y) in enumerate(corrupt_iter_to_use):
        print("The corrupted seq is ", x[0,:])
        if t >= 0:
            break


105
The seq is  tensor([ 8, 28,  7, 12, 23,  6, 21, 12,  4, 14,  4, 18, 28, 17, 16, 15,  9, 24,
        11,  9, 20, 26, 27, 24,  6,  8, 14, 21, 28, 20, 27,  1, 28,  3,  7, 10,
        28,  5,  4,  8, 16, 24, 25,  9,  5,  9, 14, 23, 20,  9,  0,  4, 23, 27,
         7, 23,  1,  9, 25,  5, 15, 26, 12,  9,  8, 17,  0, 19, 19, 20, 18,  9,
        10,  6, 16, 20, 27, 18,  6, 10, 16, 15, 15,  1, 18,  6, 24,  9, 19, 28,
        17, 22, 10, 11,  4, 15])
The corrupted seq is  tensor([ 4, 23,  8, 20, 27, 18,  8, 28,  7,  9, 14,  4, 15, 15,  1,  6,  8, 14,
         7, 23,  1, 15,  9, 24, 20,  9,  0, 21, 28, 20, 20, 18,  9, 18,  6, 24,
         9,  8, 17, 17, 22, 10, 14,  4, 18, 11,  4, 15,  6, 10, 16, 28,  5,  4,
         9, 19, 28, 21, 12,  4, 10,  6, 16,  0, 19, 19,  8, 16, 24, 26, 27, 24,
        25,  9,  5, 11,  9, 20,  3,  7, 10,  9, 25,  5, 12, 23,  6, 27,  1, 28,
        15, 26, 12, 28, 17, 16])


In [3]:
@torch.inference_mode()
def measure_perpos_accloss_full_int_preds(model, val_iter, args, device, n_measure = 1):
    """Measure per position accuracy for one batch, modular arithmetic
    """
    ctx = nullcontext() if 'mps' in args.device else torch.autocast(device_type=args.device, dtype=args.dtype, enabled=args.mixed_precision)
    model.eval()
    model.eval()
    
    acc_records, loss_records, logits_records = [], [], []
    tgts = []
    
    # acc_record = torch.zeros((args.eval_bs, args.new_seq_len+1), device=device, dtype=args.dtype)
    # loss_record = torch.zeros((args.eval_bs, args.new_seq_len+1), device=device, dtype=args.dtype)
    t1 = TIME.time()
    
    for t, (x, y) in enumerate(val_iter):
        print("At pos ",t)
        # print(TIME.time()-t1)
        t1 = TIME.time()
        x = x[:, :-1].contiguous().to(device)
        y = y[:, 1:].contiguous().to(device)
        print((x[:,2::3]!=y[:,1:-1:3]).float().mean())
        if t >= (n_measure):
            break
        # print("The seq is ", x[:2,:12])
        # print("X shape : ", x.shape)
        

        losses_ints = []
        acc_ints = []
        logits_ints = []
        
        tgts.append(y)
        
        with ctx:
            logits, qkv_list, input_list, int_list = model.record(x)
            # print("logits : ",logits.shape)
            
            for i in range(args.n_layer-1):
                x1 = model.transformer.ln_f(input_list[i])
                logits_int = model.scale * model.lm_head(x1)
                logits_ints.append(logits_int)
                loss_int = F.cross_entropy(logits_int.view(-1, logits_int.size(-1)), y.view(-1), reduction='none')
                loss_int = loss_int.reshape(logits_int.size(0),-1)
                losses_ints.append(loss_int[:, (args.dim-2)::args.dim])
                pred_int = logits_int.argmax(-1)
                correct_mask_int = (pred_int[:, (args.dim-2)::args.dim] == y[:, (args.dim-2)::args.dim])
                acc_ints.append(correct_mask_int.float())
                
            losses = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), reduction='none')
            losses = losses.reshape(logits.size(0), -1)
            pred = logits.argmax(-1)
            logits_ints.append(logits)
        
            # print("preds and tgts shape: ",pred.shape, y.shape)
            
            correct_mask = (pred[:, (args.dim-2)::args.dim] == y[:, (args.dim-2)::args.dim])
            # print(correct_mask.shape)
            
            # acc_record += correct_mask
            # loss_record += losses[:, (args.dim-2)::args.dim]
        
            acc_ints.append(correct_mask.float())
            losses_ints.append(losses[:, (args.dim-2)::args.dim])
        
            acc_records.append(acc_ints)
            loss_records.append(losses_ints)
            logits_records.append(logits_ints)

    return acc_records, loss_records, logits_records, tgts

In [None]:
### Last token accuracy and/or loss as a function of sequence length and amount of label corruption ###

no_of_tasks_tc = 64
args.max_ctx= 32
args.n_measure = 20
args.to_corrupt = 2

Ws = generate_all_unique_sublists(args)
print(len(Ws))
if args.parallelogram is True:
    Ws = parallelogram_tasks_with_shared_components(Ws, args)
check = list(set(tuple(W) for W in Ws))
# print('train Ws: \n', Ws, len(Ws), len(check))
args.pre_train_n_tasks = len(Ws)

Ws_i = [Ws[i] for i in range(no_of_tasks_tc//2) ]
Ws_o = [Ws_ood[i] for i in range(no_of_tasks_tc//2)]

Ws = Ws_i + Ws_o
args.n_tasks = len(Ws)
print(Ws)
print(len(Ws))

train_set, val_set, tokenizer = prepare_data(args, Ws)
# print(train_set[:2, :])
# exit()
original_n_train_row = train_set.size(0)
original_n_val_row = val_set.size(0)
args.vocab_size = tokenizer.__len__()
args.max_digits = tokenizer.max_digits
args.max_digits = 2 * args.max_digits if args.pos_hint is True else args.max_digits
args.dim = args.max_digits * (len(Ws[0]) + 1)


"""Make copies of data"""
if args.split_tasks is True:
    task_rows = round(args.n_tasks * (args.task_pct / 100.0))
    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)
    val_set = val_set.view((args.n_tasks - task_rows), len(val_set) // task_rows, -1)
else:
    task_rows = args.n_tasks
    train_set = train_set.view(task_rows, len(train_set) // task_rows, -1)
    val_set = val_set.view(args.n_tasks, len(val_set) // task_rows, -1)



#### Quick helper function to get the average last token accuracy and loss averaged over trials ###

def helper_last_token_acc_loss(acc_records, loss_records):
    acc_avg = np.zeros(acc_records[0][-1].size(0))
    loss_avg = np.zeros(loss_records[0][-1].size(0))
    for i in range(len(acc_records)):
        acc_avg = np.add(acc_avg,acc_records[i][-1][:,-1])
        loss_avg = np.add(loss_avg,loss_records[i][-1][:,-1])
    # print(loss_avg)
    return acc_avg/len(acc_records), loss_avg/len(loss_records)

args.eval_bs = len(Ws)

acc_avgs = np.full((args.eval_bs, args.max_ctx, args.max_ctx),np.nan)
loss_avgs = np.full((args.eval_bs, args.max_ctx, args.max_ctx),np.nan)

for ctx_len in range(1,args.max_ctx):
    args.n_point_per_row = ctx_len
    for j in range(ctx_len):
        args.no_of_corr = j

        print( float(j)/ctx_len)
        if args.train_set is True:
            corrupt_iter_to_use = get_dataloader_corrupt(train_set, args.eval_bs, args)
        else:
            corrupt_iter_to_use = get_dataloader_corrupt(val_set, args.eval_bs, args)
        
        with torch.inference_mode():
            print(len(corrupt_iter_to_use))
            acc_records, loss_records, logits_records, tgts = measure_perpos_accloss_full_int_preds(model, corrupt_iter_to_use,args, device, n_measure=args.n_measure)
            acc_avg, loss_avg = helper_last_token_acc_loss(acc_records, loss_records)
            # print(acc_avg.shape)
            
            acc_avgs[:, ctx_len, j] = acc_avg
            loss_avgs[:, ctx_len, j] = loss_avg




In [226]:
imp_datas = {"acc" : acc_avgs, "loss":loss_avgs, "wts":Ws}
        
fn = f"./data/lc_accloss_b_n_s_x_cond_maxc_{args.max_ctx}_tr{args.n_measure}_Ws{len(Ws)}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}.pkl"    
        
with open(fn, 'wb') as fp:
    pickle.dump(imp_datas,fp)

In [224]:
def plot_ctx_label_corr_random(acc_avgs, vmin, vmax, cbar_name = 'accuracy'):
    '''
    Parameters
    ----------
    acc_avgs : array of shape (# of tasks, max_cc, max_cc)
        Average accuracy (loss) 
    vmin, vmax : int, int
        vmin & vmax settings for ax.imshow
    cbar_name : str
        Name for colorbar and/or plot name

    Returns
    -------
    None.

    '''

    args.n_rows = 4
    args.n_cols = (no_of_tasks_tc)//args.n_rows

    fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(26, 15), constrained_layout=True)

    if vmin is None:
        vmin = np.nanmin(acc_avgs)*0.9
    if vmax is None:
        vmax = np.nanmax(acc_avgs)*1.1
    
    handles, labels = [], []
    for i in range(args.n_rows):
        for j in range(args.n_cols):
            ax = axs[i][j]
            
            ti = i*args.n_cols + j
            ax.set_title(f"W : {Ws[ti]}")
            to_plot = acc_avgs[ti,:,:]
            
            
            im0 = ax.imshow(to_plot, vmin=vmin, vmax=vmax,cmap = 'cividis')
            
            ax.set_ylabel('Ctx')
            ax.set_xlabel('# of Wrong labels')

            

    fig.suptitle(f'Label corruption')
    cbar = fig.colorbar(im0, ax=axs.ravel().tolist(), fraction=0.046, pad=0.04)
    cbar.set_label(cbar_name)  # Example label for the colorbar
    # fig.suptitle()    
    if args.savefig is True:
        fig.savefig(f'label_corr/label_corr_{cbar_name}_Ws{len(Ws)}_d_{args.n_layer}_h_{args.n_head}.pdf', format = 'pdf')
    plt.show()
    plt.close()


In [None]:
fn = './data/lc_accloss_b_n_s_x_cond_maxc_32_tr20_n512_h4_d2.pkl'

with open(fn, 'rb') as fp:
    imp_datas = pickle.load(fp)

acc_avgs , loss_avgs = imp_datas['acc'], imp_datas['loss']

plot_ctx_label_corr_random(acc_avgs, vmin= 0.0, vmax = 1.0,cbar_name = 'accuracy')
plot_ctx_label_corr_random(loss_avgs, vmin= None, vmax = None,cbar_name= 'loss')


In [None]:
### Plot average over tasks ###

SIZE = 40

acc_avg_tasks = acc_avgs.mean(0)
loss_avg_tasks = loss_avgs.mean(0)

args.n_cols = 2
args.n_rows = 1

fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(28, 15), constrained_layout=True)

vmin = 0.0
vmax = 1.0
    
ax = axs[0]

to_plot = acc_avg_tasks[:,:].T
im0 = ax.imshow(to_plot, vmin=vmin, vmax=vmax,cmap = 'cividis')
            
ax.set_xlabel('Shot $i$',size=SIZE)
ax.set_xticks(np.arange(to_plot.shape[1],step=4),labels=np.arange(to_plot.shape[1],step=4),size=SIZE)
ax.set_ylabel('# of Wrong labels',size=SIZE)
ax.set_yticks(np.arange(to_plot.shape[1],step=4),labels=np.arange(to_plot.shape[1],step=4),size=SIZE)
ax.invert_yaxis()
cbar_name = 'Acc'
cbar = fig.colorbar(im0, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label(cbar_name,size = SIZE)  # Example label for the colorbar
cbar.ax.tick_params(labelsize=SIZE)


vmin = np.nanmin(loss_avgs)*0.9
vmax = np.nanmax(loss_avgs)*1.1
cbar_name = 'Loss'

ax = axs[1]

to_plot = loss_avg_tasks[:,:].T
im0 = ax.imshow(to_plot, vmin=vmin, vmax=vmax,cmap = 'inferno')
            
ax.set_xlabel('Shot $i$',size=SIZE)
ax.set_xticks(np.arange(to_plot.shape[1],step=4),labels=np.arange(to_plot.shape[1],step=4),size=SIZE)
ax.set_ylabel('# of Wrong labels',size=SIZE)
ax.set_yticks(np.arange(to_plot.shape[1],step=4),labels=np.arange(to_plot.shape[1],step=4),size=SIZE)
cbar = fig.colorbar(im0, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label(cbar_name,size = SIZE)  # Example label for the colorbar
cbar.ax.tick_params(labelsize=SIZE)
ax.invert_yaxis()
# fig.suptitle(f'Label corruption, avg over tasks')
if args.savefig is True:
    fig.savefig(f'label_corr/label_corr_avg_tasks_Ws_d_2_h_4.pdf', format = 'pdf')
plt.show()
plt.close()



# Label corruption at single location

In [207]:
def custom_collate_fn_corrupt_single(batch, bs, args):
    inputs = torch.stack(batch, dim=1)
    idx = torch.randperm(inputs.size(1))
    inputs = inputs[torch.arange(args.n_tasks)[:, None], idx[None, :]]
    targets = copy.deepcopy(inputs) # (n_tasks, bs * ctx_length, dim)
    targets[:, :, :-args.max_digits] = -100 # Mask the unpredictable part
    
    # print(targets.shape)
    
    # to_corrupt = torch.randperm(args.n_point_per_row)[:args.no_of_corr]
    
    inputs[:,args.corrupt_loc,args.to_corrupt] = (inputs[:,args.corrupt_loc, args.to_corrupt]+torch.randint(1,args.p,(inputs.size(0), )))%args.p
    
    return inputs.view(bs, -1), targets.view(bs, -1)  # (bs, ctx_length * dim)


def get_dataloader_corrupt_single(dataset, bs, args):
    custom_dataset = CustomDatasetCorruption(dataset, bs, args)
    collate_fn = partial(custom_collate_fn_corrupt_single, bs=bs, args=args) # collate_fn should only take one input
    if args.ddp is False:
        g = torch.Generator()
        g.manual_seed(args.seed + 38493483)
        bss = (bs * args.n_point_per_row // args.n_tasks)
        print(type(bss), type(bs), bss,bs)
        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks), num_workers=0, collate_fn=collate_fn, drop_last=True, generator=g) 
    else:
        sampler = DistributedSampler(custom_dataset, shuffle=False)
        dataloader = DataLoader(custom_dataset, batch_size=(bs * args.n_point_per_row // args.n_tasks) // args.world_size, drop_last=True, num_workers=0, collate_fn=collate_fn, sampler=sampler) 
    return dataloader


In [None]:
args.n_measure = 3
sl = 32

args.to_corrupt = 2
### pos = 0 : corrupt x's
### pos = 1 : corrupt y's
### pos = 2 : corrupt f's
args.train_set = False
datas = {}
for seq_len in range(1,sl+1):
    args.n_point_per_row = seq_len
    for i in range(seq_len):
        # args.new_seq_len = seq_len
        args.corrupt_loc = i
        print("-"*100)
        print("Seq len : ", seq_len)
        print("-"*100)
        if args.train_set is True:
            corrupt_iter_to_use = get_dataloader_corrupt_single(train_set, args.eval_bs, args)
        else:
            corrupt_iter_to_use = get_dataloader_corrupt_single(val_set, args.eval_bs, args)
    
        with torch.inference_mode():
            print(len(corrupt_iter_to_use))
            step = 0
            for t, (x,y) in enumerate(corrupt_iter_to_use):
                x = x[:, :-1].contiguous().to(device)
                y = y[:, 1:].contiguous().to(device)
                # print(x[0,:])
                # print(y[0,:])
                step+=1
                print("Trial : ", step)
                # print((x[:,pos+s*(args.n_var+1)]==xc[:,pos+s*(args.n_var+1)]).float().mean())
            
                logits = model(x)
        
                losses = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), reduction='none')
                losses = losses.reshape(logits.size(0), -1)
                pred = logits.argmax(-1)
                loss_record = losses[:, (args.dim-2)::args.dim]
                # print("preds and tgts shape: ",pred.shape, y.shape)
        
                acc_record = (pred[:, (args.dim-2)::args.dim] == y[:, (args.dim-2)::args.dim]).float()
                # print(acc_record.mean())
                # print(loss_record.mean())
                
                datas[f"s:{seq_len}, tr:{step}, loc:{args.corrupt_loc}"] = (x,y,logits, pred, acc_record, loss_record)
                
                # datas[f"s:{seq_len}, tr:{step}"].append()
                if step >= args.n_measure:
                    break
        



In [234]:
pos=2
fn = f"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl"    
        
with open(fn, 'wb') as fp:
    pickle.dump(datas,fp)

In [None]:
### First check the mask for the initial case, and pick a case where it's correct.
### Then check what happens to the mask as we corrupt all the previous cases.

fn = f"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl"    
        
with open(fn, 'rb') as fp:
    datas = pickle.load(fp)

# print(datas.keys())
step = 1 ### trial no.
ic = np.random.randint(no_of_tasks_tc)
seq_len = 31
locc = 0
xc,yc,logits, pred, acc_record, loss_record = datas[f"s:{seq_len}, tr:{step}, loc:{locc}"]

plt.plot(acc_record[ic,:],marker=".")

In [211]:
def label_corr_exp(max_cc, step,ic=0,pos=0):
    
    fn = f"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl"    
            
    with open(fn, 'rb') as fp:
        datas = pickle.load(fp)
    
    
    mask = np.full((max_cc-1,max_cc), np.nan)

    for cc in range(1,max_cc):
        for s in range(cc):
            xc,yc,logits, pred, acc_record, loss_record = datas[f"s:{cc}, tr:{step}, loc:{s}"]
            mask[cc-1,s] = (acc_record[ic,cc-1])
    
    
    args.n_rows = 1
    args.n_cols = 1

    fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(15, 15), constrained_layout=True)
    
    ax = axs

    vmin=0.0
    vmax= 1.0
    cbar_name = "Mask"
    im0 = ax.imshow(mask, cmap = 'viridis',vmin=vmin,vmax=vmax)
    cbar = fig.colorbar(im0, ax=ax, fraction=0.046, pad=0.04)
    #cbar.set_ticks(np.linspace(vmin, vmax, num=30))
    cbar.set_label(cbar_name)

    ax.set_ylabel("Len of seq")
    ax.set_xlabel("Label corrupted")

    fig.suptitle(f'{Ws[ic]}, Trial : {step}, Pos: {pos} \n X :{xc[0,:]}')
    if args.savefig is True:
        fig.savefig(f'./label_corr/label_corr_single_pos{pos}_maxcc_{max_cc}_tr{step}_Ws{Ws[ic]}_b_n_s_{args.model_name}_p{args.p}_row{args.n_point_per_row}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_tspl{args.train_set}_ood{args.ood_tasks}_dfrac{args.data_pct:.1f}_{args.act_name}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}_I{args.seed}_dI{args.data_seed}_{args.optim}_bs{args.bs}_t{args.ckpt_step:d}_T{args.steps:d}_Tw{args.warmup_steps:d}_lr{args.lr:0.2e}_wd{args.wd:.2e}.pdf', format='pdf')
    plt.show()
    plt.close()


In [None]:
pos = 2
for step in range(1, args.n_measure+1):
    max_cc = 32
    label_corr_exp(max_cc, step,pos=pos)

In [None]:
### Better plots for paper ### 

def label_corr_exp2(max_cc, step,sl,no_tasks = 20,pos=2):
    
    fn = f"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl"    
            
    with open(fn, 'rb') as fp:
        datas = pickle.load(fp)
    
    mask = np.full((no_tasks, max_cc-1,max_cc-1), np.nan)

    for cc in range(1,max_cc):
        for s in range(cc-1):
            xc,yc,logits, pred, acc_record, loss_record = datas[f"s:{cc}, tr:{step}, loc:{s}"]
            mask[:no_tasks,cc-1,s] = (acc_record[:no_tasks,cc-1])
    
    xp = torch.cat((xc, yc[:,-1].unsqueeze(1)), dim=-1)
    # print(xp)
    # print(xp.shape)
    xp = np.array(xp.view(xp.size(0),-1,args.n_var+1))
    print(xp.shape)
    args.n_rows = 4
    args.n_cols = no_tasks//args.n_rows

    sns.set_theme(style = 'whitegrid')

    # Apply the seaborn-whitegrid style
    # plt.style.use('seaborn-whitegrid')

    fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(25, 20), constrained_layout=True)
    vmin=0.0
    vmax= 1.0
    cbar_name = "Acc"
    
    
    for i in range(args.n_rows):
        for j in range(args.n_cols):
            ax = axs[i][j]
            ti = i*args.n_cols + j
            top = mask[ti,:,:]

            im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)

            if(i==args.n_rows-1):
                ax.set_xlabel("$z'_j$")
            if (j == 0):
                ax.set_ylabel("Shot $i$")

            ax.set_title(f'W : {Ws[ti]}')

            ytick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]
            # print(xtick_labels)
            ax.set_yticks(np.arange(xp.shape[1])+1)
            ax.set_yticklabels(ytick_labels, rotation=0, fontsize=8)

    # print(axs[0][0].collections)
    
    cbar = fig.colorbar(axs[0][0].collections[0], ax=axs.ravel().tolist(), fraction=0.046, pad=0.04)
    cbar.set_ticks(np.linspace(vmin, vmax, num=2))
    cbar.set_label(cbar_name)
    

    print(mask.shape)

    

    # fig.suptitle(f'{Ws[ic]}, Trial : {step}, Pos: {pos} \n X :{xc[0,:]}')
    if args.savefig is True:
        fig.savefig(f'./label_corr/label_corr_single_pos{pos}_maxcc_{max_cc}_tr{step}_Ws{Ws[ic]}_b_n_s_{args.model_name}_p{args.p}_row{args.n_point_per_row}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_tspl{args.train_set}_ood{args.ood_tasks}_dfrac{args.data_pct:.1f}_{args.act_name}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}_I{args.seed}_dI{args.data_seed}_{args.optim}_bs{args.bs}_t{args.ckpt_step:d}_T{args.steps:d}_Tw{args.warmup_steps:d}_lr{args.lr:0.2e}_wd{args.wd:.2e}.pdf', format='pdf')
    plt.show()
    plt.close()

pos = 2
for step in range(1, 2):
    max_cc = 32
    label_corr_exp2(max_cc, step,sl=32,pos=pos)


In [None]:
### Final plot for paper ###

def label_corr_exp3(max_cc,sl, step=1,no_tasks = 20,pos=2):

    SIZE = 36
    FRAC = 0.6

    cmap = sns.color_palette(['black', '#FFFDD0'])


    args.n_rows = 1
    args.n_cols = 4

    sns.set_theme(style = 'whitegrid')

    # Apply the seaborn-whitegrid style
    # plt.style.use('seaborn-whitegrid')

    fig, axs = plt.subplots(args.n_rows, args.n_cols, figsize=(53, 13), constrained_layout=True)
    vmin=0.0
    vmax= 1.0
    cbar_name = "Acc"

    ##############################################################################
    ### TRAIN ###

    
    args.n_measure = 3
    args.train_set = True

    fn = f"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{args.n_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl"    
            
    with open(fn, 'rb') as fp:
        datas = pickle.load(fp)
    
    mask = np.full((no_tasks, max_cc-1,max_cc-1), np.nan)

    for cc in range(1,max_cc):
        for s in range(cc-1):
            xc,yc,logits, pred, acc_record, loss_record = datas[f"s:{cc}, tr:{step}, loc:{s}"]
            mask[:no_tasks,cc-1,s] = (acc_record[:no_tasks,cc-1])
    
    xp = torch.cat((xc, yc[:,-1].unsqueeze(1)), dim=-1)
    # print(xp)
    # print(xp.shape)
    xp = np.array(xp.view(xp.size(0),-1,args.n_var+1))
    print(xp.shape)

    ##############################################################################
    #### $\mathcal{S}^{\mathrm{i.d.}_{\mathrm{train}}$ ###########################
    
    ax = axs[0]
    ti = 0
    top = mask[ti,:,:].T

    # im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)
    im0 = sns.heatmap(top, ax=ax, cmap = cmap,vmin=vmin,vmax=vmax,cbar=False, cbar_kws={'ticks': [0, 1]})
    
    ax.set_ylabel("$z'_j$",size=SIZE)
    ax.set_xlabel("Shot $i$",size=SIZE)

    ax.set_title('$\mathcal{S}^{\mathrm{i.d.}}_{\mathrm{train}}$',size=1.5*SIZE)

    ytick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]
    # print(xtick_labels)
    ax.set_yticks(np.arange(xp.shape[1],step=4)+1)
    ax.set_yticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)
    ax.set_xticks(np.arange(xp.shape[1],step=4))
    ax.set_xticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)

    ax.invert_yaxis()

    ##############################################################################

    ##############################################################################
    #### $\mathcal{S}^{\mathrm{o.o.d.}_{\mathrm{train}}$ #########################

    ax = axs[1]
    ti = no_tasks//2
    top = mask[ti,:,:].T

    # im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)
    im0 = sns.heatmap(top, ax=ax, cmap = cmap,vmin=vmin,vmax=vmax,cbar=False, cbar_kws={'ticks': [0, 1]})
    
    ax.set_ylabel("$z'_j$",size=SIZE)
    ax.set_xlabel("Shot $i$",size=SIZE)

    ax.set_title('$\mathcal{S}^{\mathrm{o.o.d.}}_{\mathrm{train}}$',size=1.5*SIZE)

    ytick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]
    # print(xtick_labels)
    ax.set_yticks(np.arange(xp.shape[1],step=4)+1)
    ax.set_yticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)
    ax.set_xticks(np.arange(xp.shape[1],step=4))
    ax.set_xticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)

    ax.invert_yaxis()

    ##############################################################################

    ##############################################################################
    ### TEST ###

    args.n_measure = 5
    args.train_set = False
    
    fn = f"./data/lc_sin_b_n_s_x_pos{pos}_sl{sl}_trs{args.n_measure}_n{args.n_embd}_d{args.n_layer}_h{args.n_head}_p{args.p}_base{args.base}_ntask{no_tasks}_nvar{args.n_var}_dsplit{args.split_data}_ood{args.ood_tasks}_tspl{args.train_set}.pkl"    
            
    with open(fn, 'rb') as fp:
        datas = pickle.load(fp)
    
    mask = np.full((no_tasks, max_cc-1,max_cc-1), np.nan)

    for cc in range(1,max_cc):
        for s in range(cc-1):
            xc,yc,logits, pred, acc_record, loss_record = datas[f"s:{cc}, tr:{step}, loc:{s}"]
            mask[:no_tasks,cc-1,s] = (acc_record[:no_tasks,cc-1])
    
    xp = torch.cat((xc, yc[:,-1].unsqueeze(1)), dim=-1)
    # print(xp)
    # print(xp.shape)
    xp = np.array(xp.view(xp.size(0),-1,args.n_var+1))
    print(xp.shape)
    

    ##############################################################################
    #### $\mathcal{S}^{\mathrm{i.d.}_{\mathrm{test}}$ ############################
    
    ax = axs[2]
    ti = 0
    top = mask[ti,:,:].T

    # im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)
    im0 = sns.heatmap(top, ax=ax, cmap = cmap,vmin=vmin,vmax=vmax,cbar=False, cbar_kws={'ticks': [0, 1]})
    
    ax.set_ylabel("$z'_j$",size=SIZE)
    ax.set_xlabel("Shot $i$",size=SIZE)

    ax.set_title('$\mathcal{S}^{\mathrm{i.d.}}_{\mathrm{test}}$',size=1.5*SIZE)

    ytick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]
    # print(xtick_labels)
    ax.set_yticks(np.arange(xp.shape[1],step=4)+1)
    ax.set_yticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)
    ax.set_xticks(np.arange(xp.shape[1],step=4))
    ax.set_xticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)

    ax.invert_yaxis()

    ##############################################################################

    ##############################################################################
    #### $\mathcal{S}^{\mathrm{o.o.d.}_{\mathrm{test}}$ ##########################

    ax = axs[3]
    ti = no_tasks//2
    top = mask[ti,:,:].T

    # im0 = sns.heatmap(top, ax=ax, cmap = 'viridis',vmin=vmin,vmax=vmax,cbar=False)
    im0 = sns.heatmap(top, ax=ax, cmap = cmap,vmin=vmin,vmax=vmax,cbar=False, cbar_kws={'ticks': [0, 1]})
    
    ax.set_ylabel("$z'_j$",size=SIZE)
    ax.set_xlabel("Shot $i$",size=SIZE)

    ax.set_title('$\mathcal{S}^{\mathrm{o.o.d.}}_{\mathrm{test}}$',size=1.5*SIZE)

    xtick_labels = [tuple(xp[ti, idx]) for idx in range(xp.shape[1])]
    # print(xtick_labels)
    ax.set_yticks(np.arange(xp.shape[1],step=4)+1)
    ax.set_yticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)
    ax.set_xticks(np.arange(xp.shape[1],step=4))
    ax.set_xticklabels(np.arange(xp.shape[1],step = 4),rotation=0,size=SIZE)

    ax.invert_yaxis()

    ##############################################################################

    
    cbar = fig.colorbar(axs[0].collections[0], ax=axs.ravel().tolist(), fraction=0.046, pad=0.01, ticks=[0, 1])
    cbar.set_ticks([0.25, 0.75])
    cbar.set_ticklabels(['incorrect', 'correct'], size = SIZE)

    # fig.suptitle(f'{Ws[ic]}, Trial : {step}, Pos: {pos} \n X :{xc[0,:]}')
    if args.savefig is True:
        fig.savefig(f'./label_corr/lc_sin_fin_pos{pos}_maxcc_{max_cc}_tr{step}_n{args.n_embd}_h{args.n_head}_d{args.n_layer}.pdf',format='pdf')
    plt.show()
    plt.close()

print(args.n_tasks)
pos = 2
for step in range(1, 2):
    max_cc = 32
    label_corr_exp3(max_cc,sl=32,pos=pos)
