In [1]:
import argparse
from functools import partial

import torch
import timm
import torchvision
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from misc import collate_fn, Augment_v2
from model import MoCo


In [7]:
# No transform is passed here: both augmented views are produced later,
# inside collate_fn.
train_dataset = torchvision.datasets.ImageFolder(root="../../data/imagenet")

# The augmentation size is taken from the first image's width (size[0]).
# NOTE(review): assumes every image in the folder has this same size — confirm.
first_img = train_dataset[0][0]
img_size = first_img.size[0]

augment = Augment_v2(image_size=img_size)

def collate_fn(batch, augment):
    """
    batch: list of (img, label)
    augment: an instance of Augment_v2
    """
    x_q, x_k = [], []
    labels = []
    for img, label in batch:
        x1 = augment(img)
        x2 = augment(img)
        x_q.append(x1)
        x_k.append(x2)
        labels.append(label)

    # Stack into tensors
    x_q = torch.stack(x_q, dim=0)
    x_k = torch.stack(x_k, dim=0)

    return x_q, x_k, labels


# collate_fn needs the `augment` object, so bind it with functools.partial
# rather than a lambda. Lambdas cannot be pickled. With num_workers > 0 and
# the "spawn" start method (the default on Windows and macOS), DataLoader
# must pickle collate_fn to send it to the worker processes.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=partial(collate_fn, augment=augment),
    num_workers=10,
    pin_memory=True,
)
    

In [8]:
# Draw one batch to sanity-check the loader. The loader's collate_fn returns
# a (x_q, x_k, labels) tuple. The iterator is not kept in a variable, so it
# (and its worker processes) can be released once this cell finishes.
batch = next(iter(train_loader))

In [9]:
# Third element of the batch: the labels, which collate_fn returns as a plain Python list.
batch[2]

[68,
 41,
 115,
 173,
 67,
 138,
 15,
 151,
 148,
 90,
 55,
 119,
 163,
 112,
 198,
 17,
 157,
 166,
 122,
 138,
 159,
 92,
 89,
 67,
 191,
 72,
 65,
 7,
 172,
 83,
 47,
 4,
 66,
 142,
 162,
 159,
 187,
 22,
 196,
 131,
 176,
 119,
 154,
 195,
 5,
 189,
 85,
 33,
 135,
 5,
 192,
 84,
 89,
 95,
 62,
 147,
 0,
 49,
 152,
 66,
 160,
 122,
 149,
 125]