In [2]:
# Make the kernel automatically reload imported modules (e.g. the project's models/,
# datasets/, utils/ and run.py imported below) when their source code is changed,
# instead of having to restart the runtime.
%load_ext autoreload
# Mode 2: reload modules before each cell runs, so edits to the .py files are picked up live.
%autoreload 2

from copy import deepcopy
import os
from tqdm import tqdm
from glob import glob

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision
import torchvision.transforms as ttf
from torch.cuda.amp import GradScaler, autocast
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import accuracy_score, roc_auc_score
import hydra
from omegaconf import OmegaConf
import wandb

from models.convnext import convnext_t, my_convnext
from datasets.triplet import TripletDataset
from datasets.classification import ClassificationTestSet
from datasets.verification import VerificationDataset
from datasets.transform import AlbumTransforms, train_transforms, val_transforms
from utils.utils import weight_decay_custom, compute_kl_loss, SAM
from run import train, test, inference, face_embedding, verification_inference, gen_cls_submission, gen_ver_submission

In [3]:
# Root of the HW2P2 data. Set the HW2P2_BASE_DIR environment variable to run the
# notebook on another machine; when it is unset, the original shared-cluster path is used.
BASE_DIR = os.environ.get('HW2P2_BASE_DIR', '/shared/youngkim/hw2p2')
CLS_DIR = os.path.join(BASE_DIR, '11-785-s22-hw2p2-classification')
VER_DIR = os.path.join(BASE_DIR, '11-785-s22-hw2p2-verification')

# Classification splits. The train path is the one the original starter-code note
# asked for ("classification/classification/train"), not the smaller subset.
CLS_TRAIN_DIR = os.path.join(CLS_DIR, "classification/classification/train")
CLS_VAL_DIR = os.path.join(CLS_DIR, "classification/classification/dev")
CLS_TEST_DIR = os.path.join(CLS_DIR, "classification/classification/test")

In [4]:
# Batch size shared by the training and validation loaders.
BATCH_SIZE = 256

# Triplet datasets over the classification train/dev folders, each wrapped with the
# matching transform pipeline from datasets.transform.
train_dataset = TripletDataset(CLS_TRAIN_DIR, transform=AlbumTransforms(train_transforms))
val_dataset = TripletDataset(CLS_VAL_DIR, transform=AlbumTransforms(val_transforms))

# Training: reshuffle every epoch and drop the final partial batch so each step
# sees exactly BATCH_SIZE samples.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True, drop_last=True, num_workers=2)
# Validation: keep the final partial batch so every dev sample is evaluated.
# drop_last=True previously skipped up to BATCH_SIZE - 1 samples and biased metrics.
# NOTE(review): assumes the eval code in run.py handles a smaller last batch — confirm.
valid_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          drop_last=False, num_workers=1)


In [5]:
# Fetch a single training batch for quick inspection. next(iter(...)) raises
# StopIteration right here if the loader is empty. The previous `for ...: break`
# form instead left `item` undefined and only failed later.
item = next(iter(train_loader))

In [13]:
# Shape of element 1 of the batch's second field (recorded output: torch.Size([256]),
# matching the batch size). NOTE(review): what this field holds depends on
# TripletDataset.__getitem__ — confirm before relying on it.
item[1][1].size()

torch.Size([256])