In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import fer.data as fecdata
from pathlib import Path
import torch.nn.functional as F
# Training device. Hard-coded to the second GPU (index 1) of the original machine;
# change to 'cuda:0' or 'cpu' when running elsewhere.
device = 'cuda:1'

In [3]:
# Load PAC-to-PAC transactions, then let fecdata.prepare turn them into the model's
# `dataset`, a re-bound `df`, and the `labelers` used further down to map dense entity
# ids back to committee ids (labelers['id_labeler']).
df = fecdata.pac_to_pac_transactions()
dataset, df, labelers = fecdata.prepare(df)

In [4]:
from fer.model import Config, FECEncoder, TabDataset, TabularDenoiser
import torch

# Optimisation hyper-parameters (also logged to wandb via `lr`).
lr = 5e-4
n_epochs = 4
cfg = Config(tied_encoder_decoder_emb=True, embedding_init_std=1e-4)

# Vocabulary sizes are one past the largest id present in the data; this assumes the
# ids are 0-based integer codes (TODO confirm against fecdata.prepare).
model = TabularDenoiser(
    cfg,
    n_ttype=dataset["ttype"].max() + 1,
    n_etype=dataset["etype"].max() + 1,
    n_entities=max(dataset["src"].max(), dataset["dst"].max()) + 1,
)
tds = TabDataset(dataset)

Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda


In [5]:
from torch.utils.data import DataLoader, random_split
# Move the model to `device`, then wrap it with torch.compile (compilation itself is
# deferred until the first forward call). Attribute access on the wrapper, e.g.
# `model.encoder`, is forwarded to the original module.
model = torch.compile(model.to(device))

In [6]:
# Seeded 90/10 split, so train/validation membership is reproducible across runs.
# (The per-epoch shuffle order of `tdl` draws from torch's global RNG and is not seeded here.)
splitgen = torch.Generator().manual_seed(41)
batch_size = 3000
train_set, val_set = random_split(tds, [0.9, 0.1], generator=splitgen)

# The two loaders are configured identically except that only training data is shuffled.
# persistent_workers=True is a possible speed-up; it is intentionally left off.
tdl = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
vdl = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=8)

In [7]:
import torch.optim.lr_scheduler as lrsched
import math

# AdamW over the trainable parameters of the (compiled) model only.
# NOTE(review): `lossweighter` is created in a later cell and is not part of this
# optimizer. If UncertaintyWeightedLoss holds learnable weights, they are never stepped.
# TODO confirm.
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=lr,
)
class WarmupConstantSchedule(lrsched.LambdaLR):
    """Linear warmup followed by a constant learning-rate multiplier.

    The multiplier applied to each param group's base LR rises linearly from 0 to 1
    over the first ``warmup_steps`` scheduler steps and stays at 1 afterwards.

    Args:
        optimizer: the wrapped optimizer.
        warmup_steps: length of the linear ramp (a value below 1 is treated as 1 in
            the divisor, so no division by zero occurs).
        last_epoch: index of the last step, forwarded to ``LambdaLR``.
    """

    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        # Must be set before LambdaLR.__init__, whose initial step already calls lr_lambda.
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the multiplier for scheduler step ``step``."""
        if step >= self.warmup_steps:
            return 1.0
        return float(step) / float(max(1.0, self.warmup_steps))

class WarmupCosineSchedule(lrsched.LambdaLR):
    """Linear warmup followed by cosine decay of the learning-rate multiplier.

    For the first ``warmup_steps`` scheduler steps the multiplier rises linearly from
    0 to 1. Over the remaining ``t_total - warmup_steps`` steps it follows
    ``0.5 * (1 + cos(2 * pi * cycles * progress))``, clipped below at 0. With the
    default ``cycles=0.5`` this is a single half-cosine from 1 down to 0.

    Args:
        optimizer: the wrapped optimizer.
        warmup_steps: length of the linear ramp.
        t_total: total number of scheduler steps (warmup included).
        cycles: number of cosine periods over the decay phase.
        last_epoch: index of the last step, forwarded to ``LambdaLR``.
    """

    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        # Attributes must exist before LambdaLR.__init__ performs its initial step.
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the multiplier for scheduler step ``step``."""
        warmup = self.warmup_steps
        if step < warmup:
            return float(step) / float(max(1.0, warmup))
        # Fraction of the decay phase completed; the max() guards t_total <= warmup.
        decay_span = float(max(1, self.t_total - warmup))
        progress = float(step - warmup) / decay_span
        cosine = math.cos(math.pi * float(self.cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1. + cosine))
# 1000 warmup steps, then a half-cosine decay to zero across every batch of every epoch
# (scheduler.step() is called once per training batch).
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=1000, t_total=n_epochs * len(tdl))

In [8]:
from fer.multitask import CoVWeightingLoss, UncertaintyWeightedLoss
# One weight per scalar loss fed to the weighter by the training loop:
# 7 decoder losses on the clean encoding ("enc/..."), 7 on the denoised encoding
# ("rec/..."), plus the enc-vs-rec "dist_loss".
# CoVWeightingLoss(n_losses) is an alternative weighting scheme with the same constructor.
n_losses = 2 * 7 + 1
lossweighter = UncertaintyWeightedLoss(n_losses)

In [9]:
from tqdm.notebook import tqdm
import wandb
from dataclasses import asdict

In [10]:
# Names of the time-delta target fields, sorted so their concatenation order in
# decoder_loss is stable from run to run.
dtsks = sorted(k for k in dataset.keys() if k.startswith('scaled_dt_'))
def decoder_loss(encoded, batch):
    """Decode ``encoded`` and return one scalar loss per reconstructed field.

    Uses the notebook globals ``model`` (its ``decoder`` and ``encoder`` submodules)
    and ``dtsks``.

    Args:
        encoded: representation produced by ``model``. The training loop passes both
            the clean encoding and the denoiser's recovered one.
        batch: dict of tensors keyed by field name, already moved to the model's device.

    Returns:
        dict containing:
        - cross-entropy losses ``srcloss``, ``dstloss``, ``etloss`` and ``ttloss``;
        - MSE ``amtloss``;
        - BCE-with-logits ``amtposloss``;
        - MSE ``dt_loss`` over the concatenated ``scaled_dt_*`` targets.
    """
    # NOTE(review): the encoder is passed in, presumably so the decoder can reuse its
    # tied embedding tables (cfg.tied_encoder_decoder_emb). Calling the submodule
    # directly also runs it outside the torch.compile wrapper. TODO confirm intended.
    srclogits, dstlogits, etlogits, ttlogits, amtd, amtpos, dt_pred = model.decoder(encoded, model.encoder)

    def class_targets(key):
        # Flatten to shape (B,). A bare .squeeze() also drops the batch dimension when
        # the last batch holds a single row, leaving a 0-d target that F.cross_entropy
        # rejects for (1, C) logits.
        return batch[key].reshape(-1)

    srcloss = F.cross_entropy(srclogits, class_targets('src'))
    dstloss = F.cross_entropy(dstlogits, class_targets('dst'))
    etloss = F.cross_entropy(etlogits, class_targets('etype'))
    ttloss = F.cross_entropy(ttlogits, class_targets('ttype'))
    amtloss = F.mse_loss(amtd, batch['amt'])
    # BCE-with-logits requires floating-point targets.
    amtposloss = F.binary_cross_entropy_with_logits(amtpos, batch['amt_pos'].to(torch.float))
    # Squeeze dim 1 of each dt field (presumably a singleton) and join them along dim 1
    # so the result lines up with dt_pred.
    dt_targets = torch.cat([batch[k].squeeze(dim=1) for k in dtsks], dim=1)
    dt_loss = F.mse_loss(dt_pred, dt_targets)
    return dict(srcloss=srcloss, dstloss=dstloss, etloss=etloss, ttloss=ttloss,
                amtloss=amtloss, amtposloss=amtposloss, dt_loss=dt_loss)

In [11]:
wandb.init(project='fecentrep2', config=dict(lr=lr, **asdict(cfg)))
# NOTE(review): `vdl` is built but no validation pass is run in this loop.
# NOTE(review): the optimizer holds only model parameters. If `lossweighter` has
# learnable weights, they are never updated. TODO confirm.
model.train()  # make the mode explicit rather than relying on the constructor default
for epoch in range(n_epochs):
    with tqdm(tdl, desc=f'epoch {epoch + 1}/{n_epochs}') as t:
        for batch in t:
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            orig, _, recovered = model(batch)
            all_losses = {f'enc/{k}': v for k, v in decoder_loss(orig, batch).items()}
            all_losses.update({f'rec/{k}': v for k, v in decoder_loss(recovered, batch).items()})
            # Keep the denoised encoding close to the clean one.
            all_losses['dist_loss'] = F.mse_loss(orig, recovered)
            # Fixed (alphabetical) order, so each weighter slot always receives the same loss.
            loss_list = [lv for _, lv in sorted(all_losses.items())]
            # The weighter was sized with n_losses; fail loudly if a loss is added or removed.
            assert len(loss_list) == n_losses, f'{len(loss_list)} losses but n_losses={n_losses}'
            total_loss = lossweighter.forward(loss_list)
            all_losses['total_loss'] = total_loss
            # Log plain Python floats rather than relying on wandb's tensor conversion.
            logged = {k: v.item() for k, v in all_losses.items()}
            wandb.log(dict(**logged, lr=scheduler.get_last_lr()[0]))
            total_loss.backward()
            # Show the number, not the tensor repr (device / grad_fn noise).
            t.set_postfix(loss=f"{logged['total_loss']:.4f}")
            optimizer.step()
            scheduler.step()

[34m[1mwandb[0m: Currently logged in as: [33mapage43[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/1057 [00:00<?, ?it/s]



  0%|          | 0/1057 [00:00<?, ?it/s]

  0%|          | 0/1057 [00:00<?, ?it/s]

  0%|          | 0/1057 [00:00<?, ?it/s]

In [12]:
# Close the wandb run opened by wandb.init in the training cell and upload pending data.
wandb.finish()

VBox(children=(Label(value='0.007 MB of 0.015 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.485556…

0,1
dist_loss,█████▇▆▆▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
enc/amtloss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
enc/amtposloss,█▇▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
enc/dstloss,███████▇▇▆▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
enc/dt_loss,█▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
enc/etloss,████▇▇▅▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
enc/srcloss,██████▇▇▇▆▅▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
enc/ttloss,█████▇▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▂▃▄▄▅▆▆▇██████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁
rec/amtloss,█▆▅▃▃▂▃▂▂▂▂▂▁▂▁▂▂▂▂▁▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂

0,1
dist_loss,0.0418
enc/amtloss,0.0
enc/amtposloss,0.00134
enc/dstloss,0.33926
enc/dt_loss,0.00024
enc/etloss,0.00343
enc/srcloss,0.56437
enc/ttloss,0.00753
lr,0.0
rec/amtloss,0.07157


In [13]:
from umap import UMAP
import umap.plot as upl

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [14]:
import holoviews as hv
hv.extension('bokeh')

In [15]:
# Learned entity embedding table as a NumPy array. Presumably an nn.Embedding weight,
# so row i corresponds to entity id i (TODO confirm).
entemb = model.encoder.entity_embeddings.weight.detach().cpu().numpy()
entemb.shape

(17567, 256)

In [16]:
# UMAP projector for the entity embeddings; a small min_dist allows tighter clumps.
# NOTE(review): no random_state is set, so the 2-D layout differs between runs.
uop = UMAP(verbose=True, min_dist=0.01)

In [17]:
# 2-D coordinates, row-aligned with `entemb` (used below as the x/y columns of eframe).
e2d = uop.fit_transform(entemb)

UMAP(min_dist=0.01, verbose=True)
Sat Jul 29 21:39:30 2023 Construct fuzzy simplicial set
Sat Jul 29 21:39:30 2023 Finding Nearest Neighbors
Sat Jul 29 21:39:30 2023 Building RP forest with 12 trees
Sat Jul 29 21:39:34 2023 NN descent for 14 iterations
	 1  /  14
	 2  /  14
	 3  /  14
	 4  /  14
	 5  /  14
	 6  /  14
	Stopping threshold met -- exiting after 6 iterations
Sat Jul 29 21:39:55 2023 Finished Nearest Neighbor Search
Sat Jul 29 21:39:58 2023 Construct embedding


Epochs completed:   0%|            0/200 [00:00]

Sat Jul 29 21:40:10 2023 Finished embedding


In [18]:
import pandas as pd
# Dense id -> committee id lookup. Presumably classes_[i] is the CMTE_ID that was
# encoded as id i (LabelEncoder-style); verify against fecdata.prepare.
id2cid = labelers['id_labeler'].encoder.classes_
# Both frames get a default RangeIndex, so later joins align them row-by-row
# (row i = entity i).
eframe = pd.DataFrame(e2d, columns=['x', 'y'])
idorder = pd.DataFrame({'CMTE_ID': id2cid})

In [19]:
def read_frame(header_file, data_file, dtypes=None):
    """Read a headerless, pipe-delimited data file using column names from a header CSV.

    Args:
        header_file: path or buffer of a CSV whose header row names the columns.
        data_file: path or buffer of the ``|``-separated data file (no header row).
        dtypes: optional ``{column: dtype}`` overrides. Every column not listed is read
            as ``str``.

    Returns:
        DataFrame with the header file's column names.
    """
    header = pd.read_csv(header_file)
    # Default every column to str so identifier-like fields are not coerced to numbers;
    # caller overrides win. `None` avoids a shared mutable default argument.
    dt = {**dict.fromkeys(header.columns, str), **(dtypes or {})}
    return pd.read_csv(data_file, sep="|", names=header.columns, dtype=dt)

def read_cm(year, basedir='./data'):
    """Load one cycle's committee table from ``<basedir>/<year>/cm.txt``.

    Column names come from ``<basedir>/cm_header_file.csv``. The categorical committee
    columns listed below are explicitly requested as strings.
    """
    categorical = ("CMTE_DSGN", "CMTE_TP", "CMTE_PTY_AFFILIATION", "CMTE_FILING_FREQ")
    return read_frame(
        f"{basedir}/cm_header_file.csv",
        f"{basedir}/{year}/cm.txt",
        dtypes=dict.fromkeys(categorical, "str"),
    )

# Committee metadata, one row per entity in idorder's order. When a committee appears
# in several cycles, keep='last' keeps the row from the latest file concatenated (2024).
cmdf = idorder.join(
    pd.concat([read_cm(cycle) for cycle in (2020, 2022, 2024)])
    .drop_duplicates(subset=['CMTE_ID'], keep='last')
    .set_index('CMTE_ID'),
    on='CMTE_ID',
)

In [20]:
sz = 450
# UMAP coordinates plus committee metadata. eframe and cmdf both carry a RangeIndex in
# entity-id order, so the index join aligns them. Computed once, shared by all panels.
entity_meta = eframe.join(cmdf)
# One titled panel per committee attribute, each colored by that attribute, in a 2-column grid.
hv.Layout([
    hv.Points(entity_meta).opts(width=sz, height=sz, color=field, cmap='Category20', title=field)
    for field in ('CMTE_PTY_AFFILIATION', 'CMTE_DSGN', 'CMTE_TP', 'ORG_TP')
]).cols(2)

In [21]:
def do_atlas(do_norm=True):
    """Upload the entity embeddings with committee metadata to a Nomic Atlas map.

    Reads the notebook globals ``entemb``, ``eframe`` and ``cmdf``.
    ``reset_project_if_exists=True`` is passed, so a project with the same name is
    presumably replaced.

    Args:
        do_norm: if True (default), row-normalise the embeddings with
            ``sklearn.preprocessing.normalize`` (L2 by default) and append ``-norm``
            to the map name. If False, upload the raw embeddings.

    Returns:
        Whatever ``atlas.map_embeddings`` returns (previously discarded).
    """
    from nomic import atlas
    from sklearn.preprocessing import normalize

    # Honour the caller's choice; the flag drives both the data and the map name.
    embeddings = normalize(entemb) if do_norm else entemb
    return atlas.map_embeddings(
        embeddings,
        # Metadata rows align with embedding rows via the shared RangeIndex;
        # the 2-D UMAP coordinates are not uploaded.
        data=eframe.join(cmdf).drop(columns=['x', 'y']),
        name='fecentrep-2' + ('-norm' if do_norm else ''),
        colorable_fields=['CMTE_TP', 'CMTE_DSGN', 'ORG_TP', 'CMTE_PTY_AFFILIATION'],
        id_field='CMTE_ID',
        topic_label_field='CMTE_NM',
        reset_project_if_exists=True,
    )

In [None]:
# do_atlas()