In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import copy

  from .autonotebook import tqdm as notebook_tqdm


### **SimCLR model development**

In [2]:
# adapted from https://github.com/sthalles/SimCLR.git & https://github.com/lightly-ai/lightly/blob/master/lightly/loss/ntx_ent_loss.py#L17
class SimCLR(nn.Module):
    """
    SimCLR model: a timm ResNet-18 whose classifier is replaced by a 3-layer
    projection MLP, trained with the symmetric InfoNCE (NT-Xent) loss.

    :param n_chans: Number of input channels of the backbone (timm ``in_chans``).
    :param dim: Output dimension of the projection head.
    :param mlp_dim: Hidden width of the projection head.
    :param temperature: Temperature used by ``forward`` when computing the loss.
    :param drop_rate: Dropout rate forwarded to ``timm.create_model``.
    :param drop_path_rate: Drop-path rate forwarded to ``timm.create_model``.
    """

    def __init__(self, n_chans,
                 dim=128,
                 mlp_dim=2048,
                 temperature=0.1,
                 drop_rate=0.1,
                 drop_path_rate=0.1):
        super().__init__()
        self.temperature = temperature
        self.backbone = timm.create_model(
            "resnet18",
            in_chans=n_chans,
            num_classes=mlp_dim,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
        )

        # width of the pooled backbone features feeding the original classifier
        hidden_dim = self.backbone.fc.weight.shape[1]
        del self.backbone.fc
        self.backbone.fc = self._build_mlp(3, hidden_dim, mlp_dim, dim, False)

    def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
        """
        Build a ``num_layers``-deep bias-free MLP.

        Hidden layers are Linear -> BatchNorm1d -> ReLU. The last layer is a
        Linear, optionally followed by a non-affine BatchNorm1d.

        :return: ``nn.Sequential`` mapping ``input_dim`` -> ``output_dim``.
        """
        mlp = []
        for layer_idx in range(num_layers):
            is_last = layer_idx == num_layers - 1
            dim1 = input_dim if layer_idx == 0 else mlp_dim
            dim2 = output_dim if is_last else mlp_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if not is_last:
                mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.ReLU(inplace=True))
            elif last_bn:
                # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
                # for simplicity, we further removed gamma in BN
                mlp.append(nn.BatchNorm1d(dim2, affine=False))

        return nn.Sequential(*mlp)

    def info_nce_loss(self, z_i: torch.Tensor, z_j: torch.Tensor, temperature=0.1):
        """
        Symmetric InfoNCE loss between two batches of paired embeddings.

        For sample ``a`` of either view, the positive is the same sample in the
        other view. The negatives are all other samples of both views.

        :param z_i: Feature vectors from the first view (batch_size, feature_dim)
        :type z_i: torch.Tensor
        :param z_j: Feature vectors from the second view (batch_size, feature_dim)
        :type z_j: torch.Tensor
        :param temperature: Temperature parameter for scaling the similarity
        :type temperature: float
        :return: Scalar mean cross-entropy over the ``2 * batch_size`` anchors.
        """
        # cosine similarity via L2-normalized dot products
        z_i = F.normalize(z_i, dim=-1)
        z_j = F.normalize(z_j, dim=-1)
        B = z_i.shape[0]

        logits_00 = z_i @ z_i.T / temperature
        logits_01 = z_i @ z_j.T / temperature
        # sim(z_j[a], z_i[b]) == sim(z_i[b], z_j[a]): reuse instead of another matmul
        logits_10 = logits_01.T
        logits_11 = z_j @ z_j.T / temperature

        # remove the self-similarities (diagonals) of the same-view blocks
        off_diag = ~torch.eye(B, dtype=torch.bool, device=z_i.device)
        logits_00 = logits_00[off_diag].view(B, -1)
        logits_11 = logits_11[off_diag].view(B, -1)

        # row a: [cross-view logits (positive at column a) | same-view negatives]
        logits = torch.cat([
            torch.cat([logits_01, logits_00], dim=-1),
            torch.cat([logits_10, logits_11], dim=-1),
        ], dim=0)
        labels = torch.arange(B, dtype=torch.long, device=z_i.device).repeat(2)

        return F.cross_entropy(logits, labels, reduction='mean')

    def forward(self, x0: torch.Tensor, x1: torch.Tensor):
        """Embed both views in one backbone pass and return the InfoNCE loss."""
        feats = self.backbone(torch.cat([x0, x1], dim=0))
        z0, z1 = feats.chunk(2)

        return self.info_nce_loss(z0, z1, self.temperature)

### **Training**

In [3]:
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Lambda
import h5py
import numpy as np
from tqdm.autonotebook import tqdm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import os
import yaml

In [4]:
# Hyper-parameters and paths for this run.
config = dict(
    n_chans=4,              # backbone input channels (timm in_chans)
    dim=128,                # projection-head output dimension
    mlp_dim=2048,           # projection-head hidden width
    drop_rate=0.0,          # forwarded to timm.create_model
    drop_path_rate=0.0,     # forwarded to timm.create_model
    temperature=0.1,        # InfoNCE temperature
    batch_size=1024,
    experiment_name=f"simclr_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
    num_epochs=100,
    knn_k=200,              # neighbours used by the kNN monitor
    knn_t=0.1,              # temperature of the kNN vote weights
    learning_rate=0.03,     # base LR; the optimizer uses learning_rate * batch_size / 256
    weight_decay=1e-4,
    dataset_path="./datasets",
    checkpoint_path=None,   # path of a checkpoint to resume from, if it exists
    comment="SimCLR model training",
)

In [5]:
class R22_Dataset(Dataset):
    """
    Map-style dataset that reads samples lazily from an HDF5 file.

    Avoid loading the entire dataset into memory: ``__getitem__`` reads one
    row of the input and target datasets at a time. The HDF5 handle is opened
    lazily in the process that reads items (each DataLoader worker opens its
    own). It is never shared with a child process or pickled, because HDF5
    handles must not be shared across processes.

    :param h5_path: Path to the HDF5 file.
    :param input_label: Name of the HDF5 dataset holding the inputs.
    :param target_label: Name of the HDF5 dataset holding the targets.
    :param txfms: Optional callable applied to each input sample.
    """

    def __init__(self, h5_path,
                 input_label="iq_data",
                 target_label="angles",
                 txfms=None):
        self.h5_path = h5_path
        self.input_label = input_label
        self.target_label = target_label
        self.txfms = txfms
        self._file = None
        self._x = None
        self._y = None
        self._pid = None  # pid of the process that opened self._file
        self._len = None  # cached number of samples

    def _open(self):
        # open in read-only, latest libver, SWMR (assumes no concurrent writers)
        return h5py.File(self.h5_path, 'r', libver='latest', swmr=True)

    def _ensure_open(self):
        # (Re)open when nothing is open yet, or when the handle was inherited
        # from a parent process (e.g. a forked DataLoader worker).
        if self._file is None or self._pid != os.getpid():
            self._file = self._open()
            self._x = self._file[self.input_label]
            self._y = self._file[self.target_label]
            self._pid = os.getpid()

    def __len__(self):
        if self._len is None:
            # Use a short-lived handle so that sizing the dataset (Subset /
            # DataLoader construction in the main process) does not leave an
            # open handle behind for worker processes to inherit.
            with self._open() as f:
                self._len = f[self.input_label].shape[0]
        return self._len

    def __getitem__(self, idx):
        self._ensure_open()
        x = self._x[idx]
        y = self._y[idx]
        if self.txfms:
            x = self.txfms(x)
        return x, y

    def __getstate__(self):
        # h5py objects cannot be pickled (spawn-started workers); drop them and
        # let the receiving process reopen the file lazily.
        state = self.__dict__.copy()
        state["_file"] = None
        state["_x"] = None
        state["_y"] = None
        state["_pid"] = None
        return state

    def __del__(self):
        # getattr: __init__ may never have completed for this instance.
        handle = getattr(self, "_file", None)
        if handle is not None:
            handle.close()

In [6]:
class RandomZeroMasking:
    """
    Zero out a few random contiguous segments along one dimension.

    https://arxiv.org/pdf/2207.03046

    With ``T = x.size(dim)``:

    - one segment length ``L`` is drawn uniformly from ``[0, int(T * max_rate)]``;
    - ``max(min((T - L) // 5, 5), 1)`` start positions are drawn uniformly from ``[0, T)``;
    - the ``L`` positions following each start are zeroed (positions past the
      end are clamped to ``T - 1``).

    Because several segments may be drawn, the total masked fraction can
    exceed ``max_rate``. The input tensor is not modified.

    :param max_rate: Maximum length of each masked segment, as a fraction of T.
    :type max_rate: float (default = 0.1)
    :param dim: Dimension to apply zero masking.
    :type dim: int (default = -1)
    :rtype: torch.Tensor
    """

    def __init__(self, max_rate=.1, dim=-1):
        self.max_rate = max_rate
        self.dim = dim

    def __call__(self, x: torch.Tensor):
        length = x.size(self.dim)
        seg_len = torch.randint(0, int(length * self.max_rate) + 1, (1,)).item()
        n_segments = max(min((length - seg_len) // 5, 5), 1)
        starts = torch.randint(0, length, (n_segments,))

        # all masked positions: each start followed by seg_len consecutive steps
        positions = (starts.unsqueeze(1) + torch.arange(seg_len).unsqueeze(0)).flatten()
        positions = positions.clamp(0, length - 1).to(x.device)

        index = [slice(None)] * x.ndim
        index[self.dim] = positions
        mask = torch.ones_like(x)
        # index with a tuple: indexing with a list of slices relies on legacy
        # "sequence treated as tuple" behaviour
        mask[tuple(index)] = 0
        return x * mask


class RandomAntennaDropout:
    """
    Randomly zero whole antennas (channel pairs) of a multi-antenna sample.

    https://arxiv.org/pdf/2312.04519

    Each antenna is kept with probability ``1 - rate``. Supported layouts:

    - 3-D ``(antennas, 2, T)``: every ``(2, T)`` slice is one antenna.
    - ``(..., C, T)`` with ``C`` channel rows:
        - ``'interleaved'``: rows ``2k`` and ``2k + 1`` form antenna ``k``;
        - ``'grouped'``: rows ``k`` and ``k + C // 2`` form antenna ``k``.

      If ``C`` is odd, the last row is left untouched.

    The input tensor is never modified; the output has the input's dtype and device.

    :param rate: Dropout rate.
    :type rate: float (default = 0.1)
    :param arrangement: Channel arrangement. 'interleaved' or 'grouped'
    :type arrangement: str (default = 'grouped')
    :rtype: torch.Tensor
    """

    def __init__(self, rate=0.1, arrangement='grouped'):
        self.rate = rate
        self.arrangement = arrangement

    def _keep_mask(self, shape, like: torch.Tensor):
        """0/1 mask of ``shape`` (1 with prob ``1 - rate``) on ``like``'s dtype and device."""
        return (torch.rand(shape) > self.rate).to(dtype=like.dtype, device=like.device)

    def __call__(self, x: torch.Tensor):
        x = x.clone()  # avoid modifying the input tensor
        if x.ndim == 3:
            # x of shape (antennas, 2, T): one mask value per (2, T) slice
            x *= self._keep_mask((*x.shape[:-2], 1, 1), x)
            return x

        num_pairs = x.size(-2) // 2
        keep = self._keep_mask((*x.shape[:-2], num_pairs, 1), x)

        if self.arrangement == 'interleaved':
            # expand along the channel axis (-2), not time: rows 2k, 2k+1 share a value
            x[..., :2 * num_pairs, :] *= keep.repeat_interleave(2, dim=-2)
        elif self.arrangement == 'grouped':
            x[..., :num_pairs, :] *= keep
            x[..., num_pairs:2 * num_pairs, :] *= keep

        return x


class RandomCircularShift:
    """
    Circularly roll the last axis by a random integer offset.

    The offset is drawn uniformly from ``[-m, m]`` where
    ``m = int(T * max_shift)`` and ``T`` is the length of the last axis.

    :param max_shift: Maximum shift as a fraction of the sequence length (1.0 means full length).
    :type max_shift: float (default is 1.0)
    """

    def __init__(self, max_shift=1.0):
        self.max_shift = max_shift

    def __call__(self, x: torch.Tensor):
        bound = int(x.size(-1) * self.max_shift)
        offset = torch.randint(low=-bound, high=bound + 1, size=(1,)).item()
        return x.roll(offset, -1)


class Jitter:
    """
    Add zero-mean Gaussian noise to the input.

    https://arxiv.org/pdf/2007.15951

    :param var: Multiplier applied to standard-normal noise. Despite its name,
        this acts as the noise *standard deviation*, not the variance.
    :type var: float (default = 1e-5)
    """

    def __init__(self, var=1e-5):
        self.var = var

    def __call__(self, x: torch.Tensor):
        return x + self.var * torch.randn_like(x)

In [7]:
class SimCLRTransform:
    """
    Build the two SimCLR views of one sample.

    :param base_transform1: Callable producing the first view.
    :param base_transform2: Callable producing the second view.
    :return: ``[view1, view2]``; ``base_transform1`` is applied first.
    """

    def __init__(self, base_transform1, base_transform2):
        self.base_transform1 = base_transform1
        self.base_transform2 = base_transform2

    def __call__(self, x):
        return [self.base_transform1(x), self.base_transform2(x)]

In [8]:
def _to_float_tensor(x):
    """Convert one raw HDF5 sample to a float32 tensor.

    A named function (unlike a lambda) can be pickled. DataLoader workers
    started with the ``spawn`` method need that, once this code lives in an
    importable module.
    """
    return torch.tensor(x, dtype=torch.float32)


def _make_view_augmentation():
    """Return a fresh augmentation pipeline for one SimCLR view."""
    return Compose([
        Lambda(_to_float_tensor),
        RandomZeroMasking(max_rate=0.2),
        RandomAntennaDropout(rate=0.1),
        RandomCircularShift(max_shift=0.3),
        Jitter(var=1e-5),
    ])


def _random_subset(ds, size):
    """Sample ``size`` distinct indices of ``ds`` with torch's global RNG."""
    return torch.utils.data.Subset(ds, torch.randperm(len(ds))[:size])


# Both views share one recipe but get independent transform instances.
aug1 = _make_view_augmentation()
aug2 = _make_view_augmentation()

train_h5 = os.path.join(config["dataset_path"], "train_preprocessed.h5")
test_h5 = os.path.join(config["dataset_path"], "test_preprocessed.h5")

train_ds = R22_Dataset(
    train_h5,
    input_label="iq_data",
    target_label="angles",
    txfms=SimCLRTransform(aug1, aug2)
)
# Un-augmented training samples: feature bank for the kNN monitor.
memory_ds = R22_Dataset(train_h5, input_label="iq_data", target_label="angles")
test_ds = R22_Dataset(test_h5, input_label="iq_data", target_label="angles")

# test with a small dataset
N_TRAIN_SUBSET = 225 * 1000
N_MEMORY_SUBSET = 225 * 1000
N_TEST_SUBSET = 225 * 10

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Keep this order: the three draws consume the seeded RNG in sequence.
train_ds = _random_subset(train_ds, N_TRAIN_SUBSET)
memory_ds = _random_subset(memory_ds, N_MEMORY_SUBSET)
test_ds = _random_subset(test_ds, N_TEST_SUBSET)

prefetch_factor = 2
num_workers = 8
persistent_workers = True
pin_memory = True

loader_kwargs = dict(
    batch_size=config["batch_size"],
    prefetch_factor=prefetch_factor,
    num_workers=num_workers,
    persistent_workers=persistent_workers,
    pin_memory=pin_memory,
)
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **loader_kwargs)
memory_loader = DataLoader(memory_ds, shuffle=False, drop_last=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, drop_last=False, **loader_kwargs)

(xb1, xb2), yb = next(iter(train_loader))
print(xb1.shape, xb2.shape, yb.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.Size([1024, 4, 2, 1024]) torch.Size([1024, 4, 2, 1024]) torch.Size([1024])


In [9]:
# Sanity check: build the model from the config and compute the loss on one sample batch.
model = SimCLR(
    **{key: config[key] for key in
       ("n_chans", "dim", "mlp_dim", "temperature", "drop_rate", "drop_path_rate")}
).to(device)

print(
    f"Model has {sum(param.numel() for param in model.parameters() if param.requires_grad):,} trainable parameters")

loss = model(xb1.to(device), xb2.to(device))
print(loss)

Model has 16,692,864 trainable parameters
tensor(5.9410, device='cuda:0', grad_fn=<NllLossBackward0>)


In [10]:
class Tracker:
    """
    Remember the best value seen so far for one metric.

    ``operator(new, best)`` is ``np.less`` (minimize) or ``np.greater``
    (maximize) and tells whether ``new`` improves on ``best``. ``best``
    starts at +inf when minimizing and -inf when maximizing.

    :param metric: Metric name. With ``mode='auto'`` it is minimized when the name contains 'loss', maximized otherwise.
    :type metric: str
    :param mode: The mode of tracking. Can be 'auto', 'min', or 'max'. Default is 'auto'.
    :type mode: str, optional
    :raises KeyError: if ``mode`` is not one of the supported values.
    """

    def __init__(self, metric, mode='auto'):
        self.metric = metric
        self.mode = mode
        minimize, maximize = np.less, np.greater
        self.mode_dict = {
            'auto': minimize if 'loss' in metric else maximize,
            'min': minimize,
            'max': maximize,
        }
        self.operator = self.mode_dict[mode]
        self._best = -np.inf if self.operator is not np.less else np.inf

    @property
    def best(self):
        """Best value recorded so far."""
        return self._best

    @best.setter
    def best(self, value):
        self._best = value

In [11]:
def knn_predict(feature, feature_bank, feature_labels, num_classes, k=200, t=0.1):
    """
    Rank classes for each query by a similarity-weighted k-nearest-neighbour vote.

    Each of the ``k`` most similar bank entries votes for its label with
    weight ``exp(similarity / t)``.

    :param feature: Query features ``[b, d]`` (normalize them for cosine similarity).
    :param feature_bank: Bank features ``[d, n]``.
    :param feature_labels: Integer labels of the bank entries ``[n]``.
    :param num_classes: Number of classes.
    :param k: Number of neighbours; clamped to ``n`` if the bank is smaller.
    :param t: Temperature of the vote weights.
    :return: ``[b, num_classes]`` class indices sorted by descending score.
    """
    b = feature.size(0)
    # topk fails when k exceeds the number of bank entries
    k = min(k, feature_bank.size(-1))

    sim_mat = torch.mm(feature, feature_bank)  # [b, n]
    sim_weight, sim_indices = sim_mat.topk(k, dim=-1)  # [b, k]
    # scatter requires int64 indices regardless of the stored label dtype
    sim_labels = torch.gather(
        feature_labels.expand(b, -1), dim=-1, index=sim_indices).long()  # [b, k]

    # accumulate in at least fp32: exp(sim / t) summed over k votes overflows fp16
    score_dtype = torch.promote_types(sim_weight.dtype, torch.float32)
    sim_weight = (sim_weight.to(score_dtype) / t).exp()

    # per-class weighted vote, without materializing a [b*k, num_classes] one-hot
    pred_scores = torch.zeros(b, num_classes, dtype=score_dtype, device=feature.device)
    pred_scores.scatter_add_(1, sim_labels, sim_weight)

    return pred_scores.argsort(dim=-1, descending=True)  # [b, num_classes]


@torch.no_grad()
def knn_evaluate(model, memory_loader, test_loader, epoch, config, pbar, writer, device):
    """
    Monitor representation quality with a weighted kNN classifier.

    Embeds every batch of ``memory_loader`` with a copy of ``model.backbone``
    whose ``fc`` head is replaced by ``nn.Identity``. Each batch of
    ``test_loader`` is then classified with ``knn_predict`` against that bank.
    Top-1/top-5 accuracy is written via ``pbar.write`` and logged to
    ``writer`` as ``test/top1`` and ``test/top5``.

    :param model: Model exposing ``.backbone``; it is deep-copied, so the trained model is not modified.
    :param memory_loader: Yields ``(x, y)`` batches that form the feature bank.
    :param test_loader: Yields ``(x, y)`` batches to classify.
    :param epoch: Epoch index used in the message and as the TensorBoard step.
    :param config: Reads ``"nclasses"``, ``"knn_k"``, ``"knn_t"`` and ``"num_epochs"``.
    :param pbar: Object with a ``write(str)`` method (e.g. a tqdm bar).
    :param writer: TensorBoard ``SummaryWriter``.
    :param device: Device on which to run the encoder.
    :return: ``(top1, top5)`` accuracies in percent; NaN if ``test_loader`` is empty.
    """
    encoder = copy.deepcopy(model.backbone)
    encoder.fc = nn.Identity()
    encoder = encoder.to(device)
    encoder.eval()

    def _embed(x):
        # Cast to float32 as the training augmentations do; raw HDF5 samples
        # arrive in the file's dtype.
        return F.normalize(encoder(x.to(device, dtype=torch.float32)), dim=-1)

    feature_bank, feature_labels = [], []
    for x, y in tqdm(memory_loader, desc="Extracting features", leave=False):
        feature_bank.append(_embed(x))
        feature_labels.append(y.to(device))
    feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()  # [d, n]
    feature_labels = torch.cat(feature_labels, dim=0)  # [n]

    # loop over the test set
    total_num, top1, top5 = 0, 0, 0
    for x, y in tqdm(test_loader, desc="Testing", leave=False):
        y = y.to(device)
        pred_labels = knn_predict(_embed(x), feature_bank, feature_labels,
                                  num_classes=config["nclasses"], k=config["knn_k"], t=config["knn_t"])
        top1 += (pred_labels[:, 0] == y).sum().item()
        top5 += (pred_labels[:, :5] == y.unsqueeze(1)).sum().item()
        total_num += y.size(0)

    if total_num == 0:
        pbar.write(
            f"Epoch [{epoch}/{config['num_epochs']}] kNN evaluation skipped: empty test set")
        return float("nan"), float("nan")

    acc1 = top1 / total_num * 100
    acc5 = top5 / total_num * 100
    pbar.write(
        f"Epoch [{epoch}/{config['num_epochs']}] Acc@1: {acc1:.2f}%, Acc@5: {acc5:.2f}%"
    )
    writer.add_scalar("test/top1", acc1, epoch)
    writer.add_scalar("test/top5", acc5, epoch)
    return acc1, acc5

In [12]:
class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.

    For parameters with ``ndim > 1``:

    - weight decay is added to the gradient;
    - the update is scaled by the trust ratio
      ``trust_coefficient * ||p|| / ||update||`` (1 when either norm is 0).

    All parameters then use heavy-ball momentum (``mu = momentum * mu + update``)
    and ``p -= lr * mu``.

    :param params: Iterable of parameters or parameter-group dicts.
    :param lr: Learning rate.
    :param weight_decay: Weight decay for parameters with ``ndim > 1``.
    :param momentum: Momentum factor.
    :param trust_coefficient: LARS trust coefficient.
    """

    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay,
                        momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step.

        :param closure: Optional callable that re-evaluates the model and
            returns the loss (standard ``torch.optim.Optimizer.step`` contract).
        :return: The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])

        return loss

In [13]:
# Merge the label-encoder settings into the config. knn_evaluate reads
# config["nclasses"], presumably provided by this YAML (verify).
with open("label_encoder_angles.yaml", "r") as f:
    label_encoder_angles = yaml.safe_load(f)
config.update(label_encoder_angles)

# Persist the exact configuration next to the TensorBoard logs.
path = os.path.join("./experiments", config["experiment_name"])
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "config.yaml"), "w") as f:
    yaml.dump(config, f, default_flow_style=False)

# Linear scaling rule: lr = base_lr * batch_size / 256.
optim = LARS(
    model.parameters(),
    lr=config["learning_rate"] * config["batch_size"] / 256,
    weight_decay=config["weight_decay"],
)
# Stepped once per epoch, so T_max counts epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim,
    T_max=config["num_epochs"],
    eta_min=1e-6,
)

scaler = torch.amp.GradScaler(device=device.type, enabled=True)
writer = SummaryWriter(log_dir=path)
tracker = Tracker("loss/epoch", mode="min")

start_epoch = 0
step = 0  # global optimizer-step counter (x-axis of "loss/step")
# resume training from a checkpoint if it exists
checkpoint_path = config["checkpoint_path"]
if checkpoint_path and os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optim.load_state_dict(checkpoint["opt_state_dict"])
    scheduler.load_state_dict(checkpoint["sch_state_dict"])
    # Older checkpoints lack the keys below; fall back to fresh state.
    if "scaler_state_dict" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state_dict"])
    tracker.best = checkpoint.get("best_loss", tracker.best)
    # "epoch" is the last *completed* epoch: continue with the next one.
    start_epoch = checkpoint["epoch"] + 1
    step = checkpoint.get("step", start_epoch * len(train_loader))
    print(f"Resuming training from epoch {start_epoch}")

try:
    with tqdm(range(start_epoch, config["num_epochs"])) as master_bar:
        for epoch in master_bar:
            model.train()
            avg_loss = 0.0
            with tqdm(train_loader, leave=False) as pbar:
                # iterating the bar advances it; no manual pbar.update() needed
                for (xb1, xb2), _ in pbar:
                    xb1, xb2 = xb1.to(device), xb2.to(device)
                    optim.zero_grad()

                    with torch.amp.autocast(device_type=device.type,
                                            enabled=True):
                        loss = model(xb1, xb2)

                    scaler.scale(loss).backward()
                    # To clip gradients: scaler.unscale_(optim), then
                    # torch.nn.utils.clip_grad_norm_ before scaler.step.
                    scaler.step(optim)
                    scaler.update()

                    loss_value = loss.item()  # single host sync per step
                    avg_loss += loss_value
                    pbar.set_postfix({"loss/step": loss_value})
                    writer.add_scalar("loss/step", loss_value, step)
                    step += 1

            avg_loss /= len(train_loader)
            writer.add_scalar("loss/epoch", avg_loss, epoch)
            master_bar.write(f"Epoch {epoch}: loss = {avg_loss:.4f}")
            scheduler.step()
            writer.add_scalar(
                "learning_rate/epoch",
                optim.param_groups[0]["lr"],
                epoch,
            )

            # evaluate the model
            knn_evaluate(
                model,
                memory_loader,
                test_loader,
                epoch,
                config,
                master_bar,
                writer,
                device,
            )

            writer.flush()
            if tracker.operator(avg_loss, tracker.best):
                tracker.best = avg_loss
                # Save the model weights with the lowest epoch loss so far
                best_path = os.path.join(path, "weights.pth")
                torch.save(model.state_dict(), best_path)
                master_bar.write(f"Model saved to {best_path}")
            # save the latest checkpoint (everything needed to resume)
            checkpoint_path = os.path.join(path, "last_checkpoint.pt")
            torch.save({
                "state_dict": model.state_dict(),
                "opt_state_dict": optim.state_dict(),
                "sch_state_dict": scheduler.state_dict(),
                "scaler_state_dict": scaler.state_dict(),
                "epoch": epoch,
                "step": step,
                "best_loss": tracker.best,
            }, checkpoint_path)
            master_bar.write(
                f"Latest checkpoint saved to {checkpoint_path} at epoch {epoch}"
            )
finally:
    # Flush and close the writer even if training is interrupted.
    writer.close()

print("Training complete.")

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [01:49<?, ?it/s]

Epoch 0: loss = 3.3734


  0%|          | 0/100 [03:35<?, ?it/s]

Epoch [0/100] Acc@1: 6.93%, Acc@5: 44.49%
Model saved to ./experiments/simclr_2025-05-14_06-10-25/weights.pth


  1%|          | 1/100 [03:36<5:56:33, 216.09s/it]

Latest checkpoint saved to ./experiments/simclr_2025-05-14_06-10-25/last_checkpoint.pt at epoch 0


  1%|          | 1/100 [05:24<5:56:33, 216.09s/it]

Epoch 1: loss = 2.3000


  1%|          | 1/100 [07:09<5:56:33, 216.09s/it]

Epoch [1/100] Acc@1: 8.13%, Acc@5: 48.89%


  1%|          | 1/100 [07:09<5:56:33, 216.09s/it]

Model saved to ./experiments/simclr_2025-05-14_06-10-25/weights.pth


  2%|▏         | 2/100 [07:10<5:51:09, 214.99s/it]

Latest checkpoint saved to ./experiments/simclr_2025-05-14_06-10-25/last_checkpoint.pt at epoch 1


  2%|▏         | 2/100 [08:58<5:51:09, 214.99s/it]

Epoch 2: loss = 1.8953


  2%|▏         | 2/100 [10:44<5:51:09, 214.99s/it]

Epoch [2/100] Acc@1: 6.89%, Acc@5: 47.20%


  2%|▏         | 2/100 [10:44<5:51:09, 214.99s/it]

Model saved to ./experiments/simclr_2025-05-14_06-10-25/weights.pth


  3%|▎         | 3/100 [10:44<5:47:19, 214.84s/it]

Latest checkpoint saved to ./experiments/simclr_2025-05-14_06-10-25/last_checkpoint.pt at epoch 2


  3%|▎         | 3/100 [12:33<5:47:19, 214.84s/it]

Epoch 3: loss = 1.6567


  3%|▎         | 3/100 [14:19<5:47:19, 214.84s/it]

Epoch [3/100] Acc@1: 6.04%, Acc@5: 44.09%


  3%|▎         | 3/100 [14:19<5:47:19, 214.84s/it]

Model saved to ./experiments/simclr_2025-05-14_06-10-25/weights.pth


  4%|▍         | 4/100 [14:19<5:43:43, 214.83s/it]

Latest checkpoint saved to ./experiments/simclr_2025-05-14_06-10-25/last_checkpoint.pt at epoch 3


  4%|▍         | 4/100 [16:08<5:43:43, 214.83s/it]

Epoch 4: loss = 1.5180


  4%|▍         | 4/100 [17:53<5:43:43, 214.83s/it]

Epoch [4/100] Acc@1: 5.33%, Acc@5: 45.82%


  4%|▍         | 4/100 [17:54<5:43:43, 214.83s/it]

Model saved to ./experiments/simclr_2025-05-14_06-10-25/weights.pth


  5%|▌         | 5/100 [17:54<5:40:06, 214.81s/it]

Latest checkpoint saved to ./experiments/simclr_2025-05-14_06-10-25/last_checkpoint.pt at epoch 4


  5%|▌         | 5/100 [19:43<5:40:06, 214.81s/it]

Epoch 5: loss = 1.4044


  5%|▌         | 5/100 [21:28<5:40:06, 214.81s/it]

Epoch [5/100] Acc@1: 4.98%, Acc@5: 43.38%


  5%|▌         | 5/100 [21:29<5:40:06, 214.81s/it]

Model saved to ./experiments/simclr_2025-05-14_06-10-25/weights.pth


  6%|▌         | 6/100 [21:29<5:36:35, 214.84s/it]

Latest checkpoint saved to ./experiments/simclr_2025-05-14_06-10-25/last_checkpoint.pt at epoch 5


  6%|▌         | 6/100 [23:17<5:36:35, 214.84s/it]

Epoch 6: loss = 1.3297


  6%|▌         | 6/100 [25:03<5:36:35, 214.84s/it]

Epoch [6/100] Acc@1: 5.47%, Acc@5: 44.80%


  6%|▌         | 6/100 [25:03<5:36:35, 214.84s/it]

Model saved to ./experiments/simclr_2025-05-14_06-10-25/weights.pth


  7%|▋         | 7/100 [25:04<5:32:53, 214.76s/it]

Latest checkpoint saved to ./experiments/simclr_2025-05-14_06-10-25/last_checkpoint.pt at epoch 6


  7%|▋         | 7/100 [26:52<5:32:53, 214.76s/it]

Epoch 7: loss = 1.2629


  7%|▋         | 7/100 [28:38<5:32:53, 214.76s/it]

Epoch [7/100] Acc@1: 6.40%, Acc@5: 43.56%


  7%|▋         | 7/100 [28:38<5:32:53, 214.76s/it]

Model saved to ./experiments/simclr_2025-05-14_06-10-25/weights.pth


  8%|▊         | 8/100 [28:39<5:29:24, 214.83s/it]

Latest checkpoint saved to ./experiments/simclr_2025-05-14_06-10-25/last_checkpoint.pt at epoch 7


  8%|▊         | 8/100 [30:27<5:29:24, 214.83s/it]

Epoch 8: loss = 1.2105


  8%|▊         | 8/100 [32:13<5:29:24, 214.83s/it]

Epoch [8/100] Acc@1: 5.82%, Acc@5: 45.60%


  8%|▊         | 8/100 [32:13<5:29:24, 214.83s/it]

Model saved to ./experiments/simclr_2025-05-14_06-10-25/weights.pth


  9%|▉         | 9/100 [32:13<5:25:45, 214.79s/it]

Latest checkpoint saved to ./experiments/simclr_2025-05-14_06-10-25/last_checkpoint.pt at epoch 8


  9%|▉         | 9/100 [34:02<5:25:45, 214.79s/it]

Epoch 9: loss = 1.1694


  9%|▉         | 9/100 [35:15<5:56:25, 235.01s/it]


KeyboardInterrupt: 