# Attention Transformer NN for King Heritage Data

## Run 15

`Test_30*` - `Test_31*`

Sanity check training, where we are directly training a seq2seq model where input is same as Train_14, and output is one-hot encoded classes of which population/superpopulation the given indices fall into.


## Setup

In [1]:
#!conda activate jupyter_env
#!pip install -r "../requirements.txt"
# !pip install gputil

In [60]:
## Import meta setup

# In order to force reload any changes done to the models package files
%load_ext autoreload
%autoreload 2

# Allow import from our custom lib python files
import sys
import os

module_path = os.path.abspath(os.path.join('../'))
# module_path = os.path.abspath(os.path.join('../src/'))
if module_path not in sys.path:
    sys.path.append(module_path)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Fix for working in TLJH context

Issue with multithreading spawning with the jupyter hub setup (The Littlest Jupyter Hub on Paperspace) and working with torch, and with the dataloader multithread loading. See the following for issue discussion and solution: 

https://github.com/pytorch/pytorch/issues/40403#issuecomment-1704178443

In [3]:
import torch.multiprocessing as mp 
mp.set_start_method('spawn')

In [61]:
import os
import json


from lib.params import * # device, use_cuda, Checkpoint, various saving strs
from lib.datasets import TokenizedPopDataset, TokenizedCollateFn
from lib.models import TokenizedPopTransformer
from lib.saveload import *
from lib.training import train_model_tokenized, tokenized_masked_loss
import lib.notebook_utils as custom_info

import dill
import pandas as pd
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np

### Debug Machine Info

In [5]:
custom_info.print_python_info()
custom_info.print_imports(globals())
custom_info.print_machine_info()

Current Python executable: /opt/tljh/user/bin/python	3.10.10 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 20:08:06) [GCC 11.3.0]
Current Directory: /home/jupyter-bhavana/Notebooks
sys==Python BuiltIn
torch==2.1.0
pandas==2.0.0
numpy==1.24.4
os==unknown
json==unknown
datetime==unknown
lib==unknown
System: Linux
Node Name: psapmaq8sxf1
Release: 5.15.0-124-generic
Version: #134-Ubuntu SMP Fri Sep 27 20:20:17 UTC 2024
Machine: x86_64
Processor: x86_64
Physical cores: 8
Total cores: 8
Max Frequency: 0.00Mhz
Min Frequency: 0.00Mhz
Current Frequency: 3202.58Mhz
CPU Usage Per Core:
Core 0: 0.0%
Core 1: 0.0%
Core 2: 0.0%
Core 3: 0.0%
Core 4: 0.0%
Core 5: 0.0%
Core 6: 0.0%
Core 7: 0.0%
Total CPU Usage: 0.7%
Total: 44.07GB
Available: 42.53GB
Used: 1.07GB
Percentage: 3.5%
Total: 0.00B
Free: 0.00B
Used: 0.00B
Percentage: 0.0%
Partitions and Usage:
=== Device: /dev/mapper/ubuntu--vg-root ===
  Mountpoint: /
  File system type: ext4
  Total Size: 982.45GB
  Used: 170.79GB
  Free: 771.63GB


## Global parameters

In [3]:
runname = "test15"
machine = "Paperspace"
datapath = "../Data/king_matrix.csv"
outdir = os.path.join("../Output/Runs/", runname)
tensorboard_dir = "../Output/Tensorboard"
SEED = 42

if not os.path.exists(outdir):
        os.mkdir(outdir)
if not os.path.exists(tensorboard_dir):
    os.mkdir(tensorboard_dir)

print(f"Using {device} device")

Using cpu device


## Load Data


In [16]:
# The official 1000 genome sample names to popcodes
sample_to_popcode = pd.read_csv("../Data/igsr_samples.tsv", sep="\t")[["Sample name", "Population code", "Superpopulation code", "Superpopulation name"]].dropna()
pop_to_superpop = sample_to_popcode.set_index("Population code").to_dict()["Superpopulation code"]

# Label int to popcode
king_popcodes = pd.read_csv("../Output/Heritage_UMAP/variable_to_integer_conversion_tribe_string_labels.csv", index_col=0)

# Labels as ints
y_ints = pd.read_csv("../Output/Heritage_UMAP/labels_file_series.csv", index_col=0)["0"]
# Labels as popcodes
y_codes = y_ints.replace(king_popcodes.set_index("0").to_dict()["1"])
y_super_codes = y_codes.replace(pop_to_superpop)
y_super_ints = y_super_codes.astype("category").cat.codes

In [62]:
dsize = 20000
vstart = int(dsize * 0.8)
maxseqlen = 100
maxind = 2502
# padval = -1
batchsize = 50

dataset = TokenizedPopDataset(datapath, y_ints, dsize=dsize, maxseqlen=maxseqlen, maxind=maxind)
dataset_superpop = TokenizedPopDataset(datapath, y_super_ints, dsize=dsize, maxseqlen=maxseqlen, maxind=maxind)

# Don't need to random sample to make subsets since they're already random lens of random inds, and want even distribution
# of representative indices in both train/test
train, test = Subset(dataset, range(vstart)), Subset(dataset, range(vstart, dsize))
train_superpop, test_superpop = Subset(dataset_superpop, range(vstart)), Subset(dataset_superpop, range(vstart, dsize))

dl_args = dict(batch_size=batchsize, shuffle=True, num_workers=6)
fn=TokenizedCollateFn(dataset.padind, dataset.padval).collate_fn
fn_superpop=TokenizedCollateFn(dataset_superpop.padind, dataset_superpop.padval).collate_fn
train_dataloader, test_dataloader = DataLoader(train, **dl_args, collate_fn=fn), DataLoader(test, **dl_args, collate_fn=fn)
train_superpop_dataloader, test_superpop_dataloader = DataLoader(train_superpop, **dl_args, collate_fn=fn_superpop)\
                                                        ,DataLoader(test_superpop, **dl_args, collate_fn=fn_superpop)


print(len(dataset), len(dataset[-1][0]), dataset[0][0].shape, dataset[0][1].shape)
print(len(train_superpop), len(train_superpop[0][0]), train_superpop[2501][0].shape, train_superpop[2501])
print(next(iter(test_superpop_dataloader)))

20000 3 torch.Size([2]) torch.Size([2])
16000 2 torch.Size([28]) (tensor([ 147, 1457,  367,  751, 1486, 2499, 1700,  762, 2068, 1403,  584,   50,
        1216, 1031,  567, 1279,  594,  195, 1650, 2437, 1154,   73,  485, 2071,
        1787, 1873, 1546,  611], dtype=torch.int32), tensor([3, 4, 1, 2, 4, 4, 3, 0, 0, 4, 3, 3, 0, 0, 3, 0, 3, 2, 4, 4, 0, 3, 1, 0,
        2, 2, 4, 3]))
(tensor([[ 553, 1026, 2128,  ..., 2503, 2503, 2503],
        [1757, 1500, 1345,  ..., 2503, 2503, 2503],
        [1667, 1454, 1957,  ..., 2503, 2503, 2503],
        ...,
        [1501,  875,    7,  ..., 2503, 2503, 2503],
        [ 329, 2294, 1875,  ..., 2503, 2503, 2503],
        [1905, 2022, 1817,  ..., 2503, 2503, 2503]], dtype=torch.int32), tensor([[3, 4, 0,  ..., 5, 5, 5],
        [0, 4, 0,  ..., 5, 5, 5],
        [3, 4, 2,  ..., 5, 5, 5],
        ...,
        [4, 1, 3,  ..., 5, 5, 5],
        [1, 3, 2,  ..., 5, 5, 5],
        [0, 2, 2,  ..., 5, 5, 5]]), tensor([[False, False, False,  ...,  True,  True,  Tr

## Create Model(s)

30. Train on pop labels
31. Train on superpop labels

In [46]:
# Using random amt a bit greater than maxseqlen; and d_model = embed_size
d_model = 120

model_names = ["Test_30_pop_2h", "Test_31_superpop_2h"]
base_params = dict(d_model=d_model,
                    num_encoder_layers=3,
                    num_decoder_layers=2,
                    dim_feedforward=512,
                    activation=nn.Tanh(),
                    use_pe=False,
                    dropout_pe=0.0,
                    maxseqlen=maxseqlen, 
                    maxind=maxind,
                   num_head=2
                  )

run_details = {"run_params": dict(
                    machine=machine,
                    epochs = 80,
                    checkpoint_at = 20,
                    load=False,
                    batch_pr=int(dsize / batchsize / 5), # Print/validate every 1/5 of epoch
                    runname=runname
                    ),
                model_names[0]: dict(
                    name=model_names[0],
                    num_classes=y_ints.max() + 1
                    ) | base_params,
                model_names[1]: dict(
                    name=model_names[1],
                    num_classes=y_super_ints.max() + 1
                    ) | base_params,
                }
models = [TokenizedPopTransformer(**run_details[m]).to(device) for m in model_names]

assert models[1].padind == dataset_superpop.padind

print(models)

# Save details
with open(os.path.join(outdir, f"details_{runname}.json"), "w" ) as write:
    json.dump(run_details, write, indent=2, default=lambda x: f"nn.{x.__class__.__name__}")

[TokenizedPopTransformer(
  (pos_encoding): Identity()
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=120, out_features=120, bias=True)
        )
        (linear1): Linear(in_features=120, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=120, bias=True)
        (norm1): LayerNorm((120,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((120,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
        (activation): Tanh()
      )
    )
    (norm): LayerNorm((120,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
   

### Test model

In [47]:
# tstdata = next(iter(train_superpop_dataloader))
# f = models[0](tstdata[0], tstdata[2])
# print(f.shape, f)

torch.Size([50, 26, 98]) tensor([[[ 4.9694e-01,  8.9748e-01, -3.9411e-01,  ..., -3.4695e-01,
          -1.2233e-01, -9.7160e-02],
         [ 3.0515e-01,  3.4287e-01, -6.6585e-01,  ...,  1.4235e-01,
           3.1590e-06,  3.3469e-01],
         [ 8.4043e-01,  4.5042e-01,  7.8379e-01,  ..., -3.6689e-01,
          -2.9166e-01,  8.2611e-02],
         ...,
         [-1.2249e+00,  3.9670e-01, -3.5684e-01,  ..., -2.9317e-01,
          -8.7460e-01, -6.2133e-01],
         [-7.5833e-01, -9.0963e-01, -7.5382e-01,  ...,  1.3078e-01,
          -2.0079e-01,  6.4995e-02],
         [-1.0702e-01, -5.0885e-01, -5.0882e-01,  ...,  3.9217e-01,
           7.0239e-01,  7.3099e-01]],

        [[-2.2034e-01,  1.3193e-01, -7.5610e-02,  ..., -8.9451e-01,
          -6.1028e-01, -7.8862e-01],
         [-3.1998e-01,  5.5012e-02,  3.2584e-01,  ...,  7.3531e-01,
           7.8946e-01,  1.0029e+00],
         [ 1.9118e-02,  6.1852e-01,  2.9956e-01,  ...,  6.3022e-01,
           2.9075e-01,  8.2590e-01],
         ...,


## Train the Model(s)

In [63]:
# %%capture cap --no-stderr

loss_fcn = nn.CrossEntropyLoss(ignore_index=dataset.padval)
model = models[0]

writer = SummaryWriter(os.path.join(tensorboard_dir, f'{machine}_{model.get_name()}_{runname}'))
# Set foreach=False to avoid OOM
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, foreach=False)

train_model_tokenized(model=model,
            optimizer=optimizer,
            train_data=train_dataloader,
            validate_data=test_dataloader,
            loss_fcn=loss_fcn,
            padval=None, # No padval used with loss fn
            output_run_dir=outdir,
            **run_details["run_params"],
            writer=writer,
            output_onnx=False
        )


writer.close()


loss_fcn = nn.CrossEntropyLoss(ignore_index=dataset_superpop.padval)
model = models[1]

writer = SummaryWriter(os.path.join(tensorboard_dir, f'{machine}_{model.get_name()}_{runname}'))
# Set foreach=False to avoid OOM
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, foreach=False)

train_model_tokenized(model=model,
            optimizer=optimizer,
            train_data=train_superpop_dataloader,
            validate_data=test_superpop_dataloader,
            loss_fcn=loss_fcn,
            padval=None, # No padval used with loss fn
            output_run_dir=outdir,
            **run_details["run_params"],
            writer=writer,
            output_onnx=False
        )


writer.close()


Training Test_30_pop_2h


  return torch._native_multi_head_attention(


[0, 80] loss: 3.2303109407424926, validation loss: 3.0945286095142364, average train time (sec): 0.014232998700754252
[0, 160] loss: 2.947646087408066, validation loss: 2.756226372718811, average train time (sec): 0.0104105636375607
[0, 240] loss: 2.603339359164238, validation loss: 2.387198027968407, average train time (sec): 0.006780774737126194
[0, 320] loss: 2.2214962035417556, validation loss: 1.9376036271452903, average train time (sec): 0.0073950860125478355
[1, 80] loss: 1.72471232265234, validation loss: 1.2814940422773362, average train time (sec): 0.00881681335013127


KeyboardInterrupt: 