In [43]:
import itertools as it
import random

import einops
from plotly import express as px, graph_objects as go
import torch as t
import torch.nn.functional as F
from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from tqdm import tqdm

from src.contrastive_cnn import ContrastiveCNN
from src.dataset import Dataset, get_diff_and_sim_pair_inds

In [49]:
model = ContrastiveCNN()
ds = Dataset.load()
batches = ds.get_train_batches()
test_diff_pair_inds, test_same_pair_inds = get_diff_and_sim_pair_inds(ds.test_y)

Generating test batches for batch_size=32: 100%|██████████| 1875/1875 [00:00<00:00, 30855.35it/s]


In [118]:
# def contrastive_loss(
#     embeds: t.Tensor,  # [batch d_embed]
#     labels: t.Tensor,  # [batch (d_label=10)]
#     *,
#     epsilon: float = 1e-4
# ) -> t.Tensor:
#     assert embeds.ndim == 2
#     assert labels.ndim == 1
#     assert len(embeds) == len(labels)
#     labels_rep = einops.repeat(labels, "b -> b rep", rep=len(labels))
#     diff_pair_inds = (labels_rep == labels_rep.T).nonzero()
#     same_pair_inds = (labels_rep != labels_rep.T).nonzero()
#     # TODO: figure out how to this optimally, with tensors, einops, gather, etc, rather than loops
#     # TODO: also, better variable names
#     diff_pair_diffs = t.stack([embeds[i] - embeds[j] for i, j in diff_pair_inds])
#     same_pair_diffs = t.stack([embeds[i] - embeds[j] for i, j in same_pair_inds])
#     diff_pair_di
#     loss = (
#         diff_pair_diffs.pow(2).mean()
#         + t.max(t.tensor(0), epsilon - (same_pair_diffs**2).mean().sqrt()) ** 2
#     )
#     assert not loss.isnan()
#     # loss = diff_pair_diffs_squared.mean()
#     return loss

In [50]:
def contrastive_loss(
    embeds: t.Tensor,  # [batch d_embed]
    diff_pair_inds: t.Tensor, # [_n 2]
    same_pair_inds: t.Tensor, # [_n 2]
    *,
    epsilon: float = 1e-4
) -> t.Tensor:
    # TODO: figure out how to this optimally, with tensors, einops, gather, etc, rather than loops
    # TODO: also, better variable names
    diff_pair_diffs = t.stack([embeds[i] - embeds[j] for i, j in diff_pair_inds])
    same_pair_diffs = t.stack([embeds[i] - embeds[j] for i, j in same_pair_inds])
    diff_term_loss = diff_pair_diffs.pow(2).mean()
    same_term_loss = t.max(t.tensor(0), epsilon - (same_pair_diffs**2).mean().sqrt()) ** 2
    loss = diff_term_loss + same_term_loss
    assert not loss.isnan()
    return loss

In [51]:
def acc_fn(
    logits: t.Tensor,
    target: t.Tensor,
) -> float:
    #TODO: need to give preds here
    preds = logits.argmax(-1)
    acc = (preds == target).to(dtype=t.float).mean().item()
    return acc


def train(
    model: ContrastiveCNN,
    ds: Dataset,
    optimizer: t.optim.Optimizer,
    *,
    n_epochs: int = 10,
    batch_size: int = 32,
    verbose: bool = True,
    scheduler: LRScheduler | ReduceLROnPlateau | None = None,
) -> None:  # TrainingHistory:
    batches = ds.get_train_batches(batch_size)
    test_diff_pair_inds, test_same_pair_inds = get_diff_and_sim_pair_inds(ds.train_y)

    for epoch_i in range(1, n_epochs + 1):
        batch_losses: list[float] = []

        for batch_x, batch_y, diff_pair_inds, same_pair_inds in tqdm(batches):
            optimizer.zero_grad()
            batch_embeds = model(batch_x)
            batch_loss = contrastive_loss(batch_embeds, diff_pair_inds, same_pair_inds)
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())

        epoch_loss = t.tensor(batch_losses).mean()
        
        # with t.no_grad(): #TODO
            
        
        if isinstance(scheduler, LRScheduler):
            scheduler.step()
        elif isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(epoch_loss)

        if verbose:
            print(f"[{epoch_i}] {epoch_loss = }")
            if epoch_loss == 0:
                print(
                    f"Achieved zero loss, terminating training after {epoch_i} epochs"
                )
                break

In [52]:
model = ContrastiveCNN()
ds = Dataset.load()
lr = 1e-3
optimizer = t.optim.AdamW(model.parameters(), lr)
scheduler = ReduceLROnPlateau(optimizer)
train(model, ds, optimizer, scheduler=scheduler)


Generating test batches for batch_size=32: 100%|██████████| 1875/1875 [00:00<00:00, 30845.55it/s]


In [124]:
model = ContrastiveCNN()
ds = Dataset.load()
lr = 1e-3
optimizer = t.optim.AdamW(model.parameters(), lr)
scheduler = ReduceLROnPlateau(optimizer)
# train(model, ds, optimizer, scheduler=scheduler)

x = ds.train_x[:10]
y = ds.train_y[:10]
embeds = model(x)
loss = contrastive_loss(embeds, y)
loss

tensor(0.0046, grad_fn=<AddBackward0>)

## Testing

In [60]:
test_embeds = model(ds.test_x)
y = list(range(10))
embed_dists = t.zeros(10, 10, requires_grad=False)

with t.no_grad():
    for label1, label2 in tqdm(it.product(y, repeat=2)):
        embeds1_mean = test_embeds[ds.test_y == label1].mean(0)
        embeds2_mean = test_embeds[ds.test_y == label2].mean(0)
        embed_dist = ((embeds1_mean - embeds2_mean) ** 2).mean()
        embed_dists[label1, label2] = embed_dist

# Normalize so that the highest value is one
embed_dists *= 1 / embed_dists.max()

100it [00:00, 6898.75it/s]


In [61]:
from plotly import express as px, graph_objects as go

fig = px.imshow(embed_dists)
fig.show()

In [83]:
all_train_embeds = model(ds.train_x)
all_test_embeds = model(ds.test_x)

In [115]:
model(ds.train_x[:1])

tensor([[ 0.0048,  0.0122,  0.0070,  0.0141,  0.0066, -0.0043, -0.0110, -0.0130,
          0.0053,  0.0087, -0.0064, -0.0049,  0.0008, -0.0153, -0.0101, -0.0133,
          0.0066,  0.0019, -0.0004, -0.0066, -0.0157, -0.0006, -0.0031, -0.0074,
         -0.0059,  0.0012,  0.0101,  0.0154, -0.0008,  0.0020, -0.0039,  0.0136,
         -0.0120, -0.0157, -0.0066,  0.0029,  0.0040,  0.0157,  0.0086,  0.0022,
         -0.0093, -0.0106, -0.0134, -0.0057,  0.0021,  0.0113, -0.0017, -0.0148,
         -0.0060, -0.0138,  0.0054, -0.0019, -0.0069,  0.0096, -0.0068,  0.0013,
         -0.0032, -0.0052,  0.0035,  0.0094, -0.0027,  0.0036,  0.0142,  0.0054]],
       grad_fn=<AddmmBackward0>)

In [116]:
model(ds.train_x[1:2])

tensor([[ 0.0048,  0.0122,  0.0070,  0.0141,  0.0066, -0.0043, -0.0110, -0.0130,
          0.0053,  0.0087, -0.0064, -0.0049,  0.0008, -0.0153, -0.0101, -0.0133,
          0.0066,  0.0019, -0.0004, -0.0066, -0.0157, -0.0006, -0.0031, -0.0074,
         -0.0059,  0.0012,  0.0101,  0.0154, -0.0008,  0.0020, -0.0039,  0.0136,
         -0.0120, -0.0157, -0.0066,  0.0029,  0.0040,  0.0157,  0.0086,  0.0022,
         -0.0093, -0.0106, -0.0134, -0.0057,  0.0021,  0.0113, -0.0017, -0.0148,
         -0.0060, -0.0138,  0.0054, -0.0019, -0.0069,  0.0096, -0.0068,  0.0013,
         -0.0032, -0.0052,  0.0035,  0.0094, -0.0027,  0.0036,  0.0142,  0.0054]],
       grad_fn=<AddmmBackward0>)

In [111]:
SEED = 42
random.seed(SEED)
t.manual_seed(SEED)

# Take 10 random examples from training set for each label to make label representations
label_embeds = t.zeros(10, model.d_embed)
with t.no_grad():
    for label in y:
        train_x_index = random.sample((ds.train_y == label).nonzero().tolist(), k=10)
        stacked_embeds = t.stack([all_train_embeds[i] for i in train_x_index]).squeeze()
        print(stacked_embeds)

        label_embed = stacked_embeds.mean(0)
        # print(label, label_embed)
        label_embeds[label, :] = label_embed

tensor([[ 0.0048,  0.0122,  0.0070,  0.0141,  0.0066, -0.0043, -0.0110, -0.0130,
          0.0053,  0.0087, -0.0064, -0.0049,  0.0008, -0.0153, -0.0101, -0.0133,
          0.0066,  0.0019, -0.0004, -0.0066, -0.0157, -0.0006, -0.0031, -0.0074,
         -0.0059,  0.0012,  0.0101,  0.0154, -0.0008,  0.0020, -0.0039,  0.0136,
         -0.0120, -0.0157, -0.0066,  0.0029,  0.0040,  0.0157,  0.0086,  0.0022,
         -0.0093, -0.0106, -0.0134, -0.0057,  0.0021,  0.0113, -0.0017, -0.0148,
         -0.0060, -0.0138,  0.0054, -0.0019, -0.0069,  0.0096, -0.0068,  0.0013,
         -0.0032, -0.0052,  0.0035,  0.0094, -0.0027,  0.0036,  0.0142,  0.0054],
        [ 0.0048,  0.0122,  0.0070,  0.0141,  0.0066, -0.0043, -0.0110, -0.0130,
          0.0053,  0.0087, -0.0064, -0.0049,  0.0008, -0.0153, -0.0101, -0.0133,
          0.0066,  0.0019, -0.0004, -0.0066, -0.0157, -0.0006, -0.0031, -0.0074,
         -0.0059,  0.0012,  0.0101,  0.0154, -0.0008,  0.0020, -0.0039,  0.0136,
         -0.0120, -0.0157, 

In [109]:
label_embeds[:, 2]

tensor([0.0070, 0.0070, 0.0070, 0.0070, 0.0070, 0.0070, 0.0070, 0.0070, 0.0070,
        0.0070])

In [103]:
label_embeds[0] == label_embeds[1]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])

In [99]:
def predict_classes(
    x: t.Tensor,  # [batch 1 28 28]
) -> t.Tensor:  # [batch] (ints of classes 0:9)
    with t.no_grad():
        x_embeds = model(x)
        dists = t.stack(
            [((x_embeds - label_embed) ** 2).mean() for label_embed in label_embeds]
        )
    print(dists)


predict_classes(ds.test_x[:1])

tensor([3.3089e-19, 3.3089e-19, 3.3089e-19, 3.3089e-19, 3.3089e-19, 3.3089e-19,
        3.3089e-19, 3.3089e-19, 3.3089e-19, 3.3089e-19])


In [None]:
# x = ds.train_x[:100]
# y = ds.train_y[:100]
# embeds = model(x)
# loss = contrastive_loss(embeds,y)
# loss.backward()

In [92]:
# for i, j in contrastive_pairs:
#     embed1 = embeds[i]
#     embed2 = embeds[j]
#     print(embed1.shape, embed2.shape)
#     break

torch.Size([64]) torch.Size([64])


In [86]:
# embeds.shape, contrastive_pairs.shape

(torch.Size([100, 64]), torch.Size([1074, 2]))

In [85]:
# t.gather(embeds, dim=0, index=contrastive_pairs)

tensor([[ 0.0207, -0.0079],
        [ 0.0207, -0.1484],
        [ 0.0207, -0.2121],
        ...,
        [ 0.0572, -0.0880],
        [ 0.0572, -0.0749],
        [ 0.0572, -0.1148]], grad_fn=<GatherBackward0>)

In [66]:
# t.tensor(list(it.product(labels_sim_mat[0].nonzero().flatten(), repeat=2)))

tensor([[ 0,  0],
        [ 0, 11],
        [ 0, 35],
        [ 0, 47],
        [ 0, 65],
        [11,  0],
        [11, 11],
        [11, 35],
        [11, 47],
        [11, 65],
        [35,  0],
        [35, 11],
        [35, 35],
        [35, 47],
        [35, 65],
        [47,  0],
        [47, 11],
        [47, 35],
        [47, 47],
        [47, 65],
        [65,  0],
        [65, 11],
        [65, 35],
        [65, 47],
        [65, 65]])