In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

### **CPC model development**

In [13]:
class CPC(nn.Module):
    """Contrastive Predictive Coding model: a ResNet-18 chunk encoder plus a GRU context network.

    The input is cut into non-overlapping chunks of ``slice_length`` samples
    along its last axis. Every chunk is embedded by the encoder; the first
    ``history_steps`` embeddings are summarised by a single-layer GRU, and the
    remaining embeddings are returned as the "future" targets for the
    contrastive loss.

    :param in_features: Channels seen by the encoder (dim 1 of the input, ``Ca``).
    :param hidden_features: Embedding size D of the encoder output and the GRU state.
    :param slice_length: Length of each non-overlapping chunk along the time axis.
    :param history_steps: Number of leading chunks fed to the GRU.
    :param drop_rate: Dropout rate forwarded to the timm encoder.
    :param drop_path_rate: Stochastic-depth rate forwarded to the timm encoder.
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: int,
                 slice_length: int = 1024,
                 history_steps: int = 8,
                 drop_rate: float = 0.2,
                 drop_path_rate: float = 0.7):
        super().__init__()

        self.slice_length = slice_length
        self.history_steps = history_steps

        # Each chunk (Ca, C, slice_length) is treated as a Ca-channel 2-D "image"
        # of height C and width slice_length; the classifier head emits D features.
        self.encoder = timm.create_model(
            "resnet18",
            in_chans=in_features,
            num_classes=hidden_features,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
        )

        # Autoregressive context network over the sequence of chunk embeddings.
        self.autoregressor = nn.GRU(
            input_size=hidden_features,
            hidden_size=hidden_features,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor):
        """Embed every chunk and summarise the history with the GRU.

        :param x: Input of shape (B, Ca, C, T); ``Ca`` must equal ``in_features``.
            Trailing samples that do not fill a whole chunk are dropped.
        :returns: ``(c_t, z_fut, h_n)``: ``c_t`` is the GRU output at the last
            history step, shape (B, D); ``z_fut`` holds the embeddings of the
            chunks after the history, shape (B, N - history_steps, D); ``h_n``
            is the GRU's final hidden state, shape (1, B, D).
        :raises ValueError: If the input yields ``N <= history_steps`` chunks,
            i.e. there would be no future chunk to predict.
        """
        # (B, Ca, C, T) -> (B, Ca, C, N, slice_length) -> (B, N, Ca, C, slice_length)
        chunks = x.unfold(3, self.slice_length, self.slice_length).permute(0, 3, 1, 2, 4)
        batch, n_chunks, ca, ch, length = chunks.size()
        if n_chunks <= self.history_steps:
            raise ValueError(
                f"input yields {n_chunks} chunk(s) of length {self.slice_length}, "
                f"but history_steps={self.history_steps}; at least "
                f"{self.history_steps + 1} chunks (T >= "
                f"{(self.history_steps + 1) * self.slice_length}) are required"
            )

        # Encode all chunks of all samples in one batched pass.
        z = self.encoder(chunks.reshape(batch * n_chunks, ca, ch, length))
        z = z.view(batch, n_chunks, -1)  # (B, N, D)

        # The GRU starts from its default zero state; only the history is fed in.
        c, h_n = self.autoregressor(z[:, :self.history_steps])  # (B, history_steps, D)

        return c[:, -1], z[:, self.history_steps:], h_n

In [14]:
def nt_xent_loss(c_t, z_fut, temperature=1.0):
    """InfoNCE-style loss that asks the context vector to pick out the *next* chunk.

    Both inputs are L2-normalised, so the logits are cosine similarities scaled
    by ``1 / temperature``. For every sample the first future embedding (index 0
    along K) is the positive; the remaining K-1 future embeddings of the *same*
    sample act as negatives. With K == 1 there are no negatives and the loss is 0.

    :param c_t: Context vectors, shape (B, D).
    :param z_fut: Future chunk embeddings, shape (B, K, D) with K >= 1.
    :param temperature: Softmax temperature; smaller values sharpen the distribution.
    :returns: Scalar mean cross-entropy over the batch.
    :raises ValueError: If ``z_fut`` contains no future steps (K == 0).
    """
    if z_fut.size(1) == 0:
        raise ValueError("z_fut must contain at least one future step (K >= 1)")

    c_t = F.normalize(c_t, p=2, dim=-1)
    z_fut = F.normalize(z_fut, p=2, dim=-1)

    # cosine similarity of the context with every future step
    logits = torch.einsum("bd,bkd->bk", c_t, z_fut) / temperature  # (B, K)

    # the positive is always the immediately following chunk (index 0)
    targets = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

    return F.cross_entropy(logits, targets)

In [15]:
from torchvision.datasets import DatasetFolder
from torch.utils.data import DataLoader
import numpy as np
from tqdm.autonotebook import tqdm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import os
import yaml

### **Training**

In [None]:
# Hyperparameter sets, presumably from an earlier tuning search. The first set is
# the one copied into `config` below; the second is an alternative kept for reference.
# params={'lr': 0.00026324240451597926, 'weight_decay': 0.0006328245579438751, 'temperature': 0.8613165981723009, 'slice_length': 256, 'history_steps': 8, 'drop_rate': 0.014582702150573057, 'drop_path_rate': 0.3097614472574182}
# params={'lr': 0.004579759039904434, 'weight_decay': 0.00012386582585723483, 'temperature': 0.0966689791227939, 'slice_length': 1536, 'history_steps': 24, 'drop_rate': 0.35419483772460636, 'drop_path_rate': 0.3174014557648441}
config = {
    "in_features": 4,  # channels seen by the encoder (dim 1 of each batch)
    "hidden_features": 128,
    "slice_length": 256,
    "history_steps": 8,
    "drop_rate": 0.014582702150573057,
    "drop_path_rate": 0.3097614472574182,
    "temperature": 0.8613165981723009,
    "batch_size": 4,
    "experiment_name": f"cpc_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
    "num_epochs": 30,
    "learning_rate": 0.00026324240451597926,
    "weight_decay": 0.0006328245579438751,
    "seed": 42,
    "dataset_path": "./oct10_outdoor_gain_experiments",
    "checkpoint_path": None,  # set to a last_checkpoint.pt path to resume training
    "comment": "CPC model training",
}

# dump the config next to the TensorBoard logs and checkpoints of this run
path = os.path.join("./experiments", config["experiment_name"])
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "config.yaml"), "w") as f:
    yaml.dump(config, f, default_flow_style=False)

# Seed the RNGs behind weight init, dropout and DataLoader shuffling so runs are
# repeatable (cuDNN kernels may still add small non-determinism on GPU).
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

In [None]:
class R22_Dataset(DatasetFolder):
    """
    Label-free DatasetFolder whose items are all cut to one common length:
    the shortest last-axis length found across every file in the dataset.
    """

    def __init__(self, root, load_fn, transform=None, **kwargs):
        super().__init__(root, loader=load_fn, transform=transform, **kwargs)
        # Scan every file once. With mmap_mode='r' only the .npy header is
        # parsed; the array payload is not read into memory.
        self.min_length = min(
            np.load(sample_path, mmap_mode='r', allow_pickle=False).shape[-1]
            for sample_path, _label in self.samples
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_path = self.samples[index][0]
        item = self.loader(sample_path)
        if self.transform is not None:
            item = self.transform(item)
        # The class label is discarded; only the trimmed sample is returned.
        return item[..., :self.min_length]


def load_fn(path):
    """Load one recording and return standardised float32 real/imag channels.

    The file is read as two consecutive ``np.save`` payloads: a data array,
    followed by a pickled object whose ``.item()`` is read and discarded.
    Real and imaginary parts are stacked on a new axis 1. The result is then
    standardised to zero mean and unit variance over axes (1, 2), separately
    for each index along axis 0.

    NOTE(review): assumes the data array is 2-D (rows, samples) — TODO confirm.
    For higher-rank arrays, axes (1, 2) would not include the time axis.
    NOTE(review): the metadata payload is loaded with allow_pickle=True, which
    can execute code from a crafted file. Only use trusted data files.
    """
    with open(path, "rb") as fh:
        raw = np.load(fh)
        np.load(fh, allow_pickle=True).item()  # metadata record, unused here
    stacked = np.stack([raw.real, raw.imag], axis=1)
    reduce_axes = (1, 2)
    centre = stacked.mean(axis=reduce_axes, keepdims=True)
    spread = stacked.std(axis=reduce_axes, keepdims=True)
    # epsilon guards against division by zero for constant rows
    return ((stacked - centre) / (spread + 1e-8)).astype(np.float32)


dataset = R22_Dataset(
    root=config["dataset_path"],
    load_fn=load_fn,
    # A plain function, unlike a lambda, can be pickled. This keeps the dataset
    # usable by DataLoader workers started with the "spawn" method.
    transform=torch.from_numpy,
    # DatasetFolder documents `extensions` as a tuple of suffixes
    extensions=(".npy",),
)

# sanity check: every item has the common trimmed length on its last axis
assert dataset[0].shape[-1] == dataset.min_length
dataset[0].shape

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataloader = DataLoader(
    dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=4,
    # Pinned host memory only speeds up host->GPU copies. Enabling it without
    # CUDA just triggers a warning, so it is tied to the device type.
    pin_memory=device.type == "cuda",
)

# one batch for the model smoke test below
xb = next(iter(dataloader))
xb.shape

In [None]:
model = CPC(
    in_features=config["in_features"],
    hidden_features=config["hidden_features"],
    slice_length=config["slice_length"],
    history_steps=config["history_steps"],
    drop_rate=config["drop_rate"],
    drop_path_rate=config["drop_path_rate"],
).to(device)

# Shape/loss smoke test on one batch. eval() stops BatchNorm from updating its
# running statistics before training starts. no_grad() avoids building an unused
# autograd graph. The training loop switches back with model.train().
model.eval()
with torch.no_grad():
    c_t, z_fut, _ = model(xb.to(device))
    loss = nt_xent_loss(c_t, z_fut, temperature=config["temperature"])
print(c_t.shape, z_fut.shape)
print(loss)

In [None]:
class Tracker:
    """
    Remember the best value observed so far for a single metric.

    :param metric: Name of the tracked metric. With ``mode='auto'``, a name containing 'loss' is minimised and anything else is maximised.
    :type metric: str
    :param mode: One of 'auto', 'min' or 'max'. Default is 'auto'.
    :type mode: str, optional
    """

    def __init__(self, metric, mode='auto'):
        self.metric = metric
        self.mode = mode
        auto_operator = np.less if 'loss' in metric else np.greater
        self.mode_dict = dict(auto=auto_operator, min=np.less, max=np.greater)
        # an unknown mode raises KeyError here
        self.operator = self.mode_dict[mode]
        # start from the worst possible value so the first observation always wins
        if self.operator == np.less:
            self._best = np.inf
        else:
            self._best = -np.inf

    @property
    def best(self):
        """Best metric value recorded so far."""
        return self._best

    @best.setter
    def best(self, value):
        self._best = value

In [None]:
optim = torch.optim.AdamW(
    model.parameters(),
    lr=config["learning_rate"],
    weight_decay=config["weight_decay"],
)
# cosine decay from the base lr to eta_min over num_epochs; stepped once per epoch
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim,
    T_max=config["num_epochs"],
    eta_min=1e-6,
)

scaler = torch.amp.GradScaler(device=device.type, enabled=True)
experiment_dir = os.path.join("./experiments", config["experiment_name"])
writer = SummaryWriter(log_dir=experiment_dir)
tracker = Tracker("loss/epoch", mode="min")
best_weights_path = os.path.join(experiment_dir, "weights.pth")
last_checkpoint_path = os.path.join(experiment_dir, "last_checkpoint.pt")

start_epoch = 0
step = 0  # global optimisation step, used as the TensorBoard x-axis
# resume training from a checkpoint if it exists
checkpoint_path = config["checkpoint_path"]
if checkpoint_path and os.path.exists(checkpoint_path):
    # map_location lets a GPU-saved checkpoint be resumed on a CPU-only machine
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optim.load_state_dict(checkpoint["opt_state_dict"])
    scheduler.load_state_dict(checkpoint["sch_state_dict"])
    # The keys below are missing from checkpoints written by older versions of
    # this cell, so each one falls back to a sensible default.
    if "scaler_state_dict" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state_dict"])
    # "epoch" is the last *completed* epoch, so continue with the next one
    start_epoch = checkpoint["epoch"] + 1
    step = checkpoint.get("step", start_epoch * len(dataloader))
    tracker.best = checkpoint.get("best", tracker.best)
    print(f"Resuming training from epoch {start_epoch}")

with tqdm(range(start_epoch, config["num_epochs"])) as master_bar:
    for epoch in master_bar:
        model.train()
        avg_loss = 0.0
        with tqdm(dataloader, leave=False) as pbar:
            for xb in pbar:
                xb = xb.to(device)
                optim.zero_grad()

                with torch.amp.autocast(device_type=device.type,
                                        enabled=True):
                    c_t, z_fut, _ = model(xb)
                    # use the tuned temperature, not the function's default of 1.0
                    loss = nt_xent_loss(
                        c_t, z_fut, temperature=config["temperature"])

                scaler.scale(loss).backward()
                # unscale first so clipping applies to the true gradient magnitudes
                scaler.unscale_(optim)
                norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), 1.0, norm_type=2
                )
                scaler.step(optim)
                scaler.update()

                # read each scalar once (every .item() is a device sync)
                loss_value = loss.item()
                norm_value = norm.item()
                avg_loss += loss_value

                pbar.set_postfix({"loss/step": loss_value, "norm": norm_value})
                writer.add_scalar("loss/step", loss_value, step)
                writer.add_scalar("norm/step", norm_value, step)
                step += 1

        avg_loss /= len(dataloader)
        writer.add_scalar("loss/epoch", avg_loss, epoch)
        master_bar.write(f"Epoch {epoch}: loss = {avg_loss:.4f}")
        scheduler.step()
        writer.add_scalar(
            "learning_rate/epoch",
            optim.param_groups[0]["lr"],
            epoch,
        )
        writer.flush()

        if tracker.operator(avg_loss, tracker.best):
            tracker.best = avg_loss
            torch.save(model.state_dict(), best_weights_path)
            master_bar.write(f"Model saved to {best_weights_path}")

        # Save everything needed to resume exactly where this epoch ended.
        torch.save({
            "state_dict": model.state_dict(),
            "opt_state_dict": optim.state_dict(),
            "sch_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "epoch": epoch,
            "step": step,
            "best": tracker.best,
        }, last_checkpoint_path)
        master_bar.write(
            f"Latest checkpoint saved to {last_checkpoint_path} at epoch {epoch}"
        )

writer.close()
print("Training complete.")

##### **Notes**

- Training is unstable: the gradient norm fluctuates strongly and the loss is very noisy.
- Gradient clipping and per-sample data normalization reduced the loss noise, but there is room for further improvement. Whole-batch normalization is still worth trying.
- TODO: track the silhouette score during training to determine whether the clustering is improving.

### **Inspection**

In [None]:
# load the best model
# plot the TSNE and UMAP of the features
# plot pseudo-labels using KMeans and compare with the ground truth
encoder_model = timm.create_model(
    "resnet18",
    in_chans=config["in_features"],
    num_classes=config["hidden_features"],
    drop_rate=config["drop_rate"],
    drop_path_rate=config["drop_path_rate"],
).to(device)

# best weights of the run being inspected; change to inspect another experiment
INSPECT_WEIGHTS_PATH = "./experiments/cpc_2025-05-05_01-16-11/weights.pth"
weights = torch.load(
    INSPECT_WEIGHTS_PATH, weights_only=True, map_location=device
)

# A CPC state dict also contains the GRU ("autoregressor.*"). Keep only the
# encoder parameters and strip their leading "encoder." prefix. A plain encoder
# state dict (no prefix) is used as-is.
prefix = "encoder."
if any(k.startswith(prefix) for k in weights):
    weights = {k[len(prefix):]: v for k, v in weights.items()
               if k.startswith(prefix)}
# strict=True raises on any missing or unexpected key
msg = encoder_model.load_state_dict(weights, strict=True)

# For feature extraction, disable dropout and drop-path, and use the BatchNorm
# running statistics instead of per-batch statistics.
encoder_model.eval()