In [2]:
# Mount Google Drive into the Colab VM at /content/drive so the notebook can
# read files stored there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
# Make the project folder on the mounted Drive the working directory, so the
# relative paths used below ("WebFace260M", "configs/...") resolve inside it.
# NOTE(review): absolute, machine-specific path — consider a configurable constant.
os.chdir("/content/drive/MyDrive/Main Project VDT")

In [4]:
# Identity list: every entry of WebFace260M whose name starts with "0_0_".
# Sorted so the list is identical on every run (os.listdir order is arbitrary),
# and built with a comprehension so no loop variable named `id` leaks into the
# notebook namespace and shadows the builtin.
ids = sorted(
    entry for entry in os.listdir("WebFace260M") if entry.startswith("0_0_")
)
print("Number of id:", len(ids))

Number of id: 1000


In [5]:
# Data set index: instances[identity] = list of file names found in
# WebFace260M/<identity>. File names are sorted for run-to-run determinism,
# and the comprehension avoids leaking a loop variable named `id` that would
# shadow the builtin for the rest of the notebook.
instances = {
    identity: sorted(os.listdir(os.path.join("WebFace260M", identity)))
    for identity in ids
}

In [None]:
def train_val_test_split():
    """Placeholder for a dataset-splitting helper; not implemented yet.

    The original cell contained only an unfinished ``def`` line (no colon, no
    body), which made the whole cell a SyntaxError. The stub keeps the name
    available while making the cell runnable.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("train_val_test_split is not implemented yet")


# Number of files listed for each identity.
stats = [len(files) for files in instances.values()]
print("Number of data: ", sum(stats))
# How many identities have exactly three files.
stats.count(3)

Number of data:  19396


90

In [None]:
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

class Imdataset(Dataset):
    """In-memory dataset of ``(image, label)`` pairs with an optional transform.

    Args:
        instances: iterable of ``(image, label)`` pairs. It is materialised
            into a list so ``len()`` and indexing always work.
        transform: optional callable applied to the image every time an item
            is fetched. Applying it lazily (instead of once in ``__init__``)
            means random augmentations produce a fresh sample on each access
            rather than being frozen at construction time.
    """

    def __init__(self, instances, transform=None):
        super().__init__()
        self.instances = list(instances)
        self.transform = transform

    def __len__(self):
        """Return the number of stored ``(image, label)`` pairs."""
        return len(self.instances)

    def __getitem__(self, index):
        """Return the pair at ``index``, with the transform applied to the image."""
        if self.transform is None:
            return self.instances[index]
        image, label = self.instances[index]
        return self.transform(image), label

    def collate_fn(self, batch):
        """Combine a list of ``(image_tensor, label)`` items into a batch.

        Args:
            batch: sequence of items whose first element is an image tensor
                (all of the same shape) and whose second element is a label.

        Returns:
            ``[images, labels]`` where ``images`` is the image tensors stacked
            along a new leading dimension and ``labels`` is a tensor built
            from the labels.
        """
        # torch.tensor() cannot build a tensor from a list of multi-element
        # tensors; torch.stack() is the correct way to batch them.
        images = torch.stack([inst[0] for inst in batch])
        labels = torch.tensor([inst[1] for inst in batch])
        return [images, labels]

def collect_images(task, split_ratio, number_id):
    """Load the images of one split ("Train", "Valid" or "Test") into memory.

    For each of the first ``number_id`` identity folders (in sorted order)
    under ``WebFace260M``, the ``.jpg`` files are sorted and partitioned into
    three consecutive, non-overlapping ranges: train first, then validation,
    then test. Only the range belonging to ``task`` is loaded.

    The original implementation computed the split slice but then iterated
    over the whole folder again, so every split contained every image. The
    directory listings were also unsorted, so the partition was not guaranteed
    to be the same between calls.

    Args:
        task: one of ``"Train"``, ``"Valid"``, ``"Test"``.
        split_ratio: mapping with ``"Valid"`` and ``"Test"`` fractions. The
            train split receives whatever remains after the other two.
        number_id: how many identity folders to use; may equal the total
            number of folders.

    Returns:
        list of ``(PIL.Image in RGB mode, int label)`` tuples. The label is
        the integer parsed from the folder name after its first four
        characters (the folder names seen earlier in this notebook start
        with ``0_0_``).

    Raises:
        AssertionError: if ``task`` is unknown or ``number_id`` exceeds the
            number of entries in the dataset folder.
    """
    dataset = []
    assert task in ["Train", "Valid", "Test"]
    folder_dir = "WebFace260M"
    # Sort so that the chosen identities are the same on every call.
    id_dirs = sorted(os.listdir(folder_dir))
    assert number_id <= len(id_dirs)
    for id_name in tqdm(id_dirs[:number_id]):
        id_dir = os.path.join(folder_dir, id_name)
        # Filter before splitting so the ratios apply to actual images, and
        # sort so Train/Valid/Test calls partition the files identically.
        id_images = sorted(
            name for name in os.listdir(id_dir) if name.endswith(".jpg")
        )
        id_images_num = len(id_images)
        # int() truncates, so identities with few images may contribute
        # nothing to the validation/test splits.
        test_num = int(split_ratio["Test"] * id_images_num)
        valid_num = int(split_ratio["Valid"] * id_images_num)
        train_end = id_images_num - valid_num - test_num
        valid_end = id_images_num - test_num
        if task == "Train":
            selected = id_images[:train_end]
        elif task == "Valid":
            selected = id_images[train_end:valid_end]
        else:
            selected = id_images[valid_end:]

        label = int(id_name[4:])
        for file_name in selected:
            # convert() returns a fully loaded copy, so the file handle can be
            # closed right away instead of being left open.
            with Image.open(os.path.join(id_dir, file_name)) as img:
                dataset.append((img.convert('RGB'), label))

    return dataset

def prepare_data(number_id=10, batch_size=4, num_workers=2, split_ratio=None, train_sample_size=None, valid_sample_size=None, test_sample_size=None):
    """Build the train, validation and test ``DataLoader`` objects.

    Args:
        number_id: number of identity folders to load (passed to
            ``collect_images``).
        batch_size: batch size used by all three loaders.
        num_workers: worker process count used by all three loaders.
        split_ratio: mapping with ``"Train"``, ``"Valid"`` and ``"Test"``
            fractions. Defaults to 0.7 / 0.2 / 0.1 when omitted (previously
            the ``None`` default failed inside ``collect_images``).
        train_sample_size: if given, keep only a random subset of this many
            training samples.
        valid_sample_size: same, for the validation set.
        test_sample_size: same, for the test set.

    Returns:
        ``(trainloader, validloader, testloader)``. Only the training loader
        shuffles.
    """
    if split_ratio is None:
        split_ratio = {"Train": 0.7, "Valid": 0.2, "Test": 0.1}

    # Training pipeline: resize to 90x90, then random flip / random resized
    # crop as augmentation, then scale each channel with mean 0.5 and std 0.5.
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((90, 90)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop(
            (90, 90),
            scale=(0.8, 1.0),
            ratio=(0.75, 1.3333333333333333),
            # Enum form of the integer code 2 (bilinear); passing the raw
            # integer is deprecated in torchvision.
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Evaluation pipeline: same resize and normalisation, no augmentation.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((90, 90)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    def _build_loader(task, transform, sample_size, shuffle):
        """Load one split, optionally subsample it, and wrap it in a DataLoader."""
        dataset = Imdataset(collect_images(task, split_ratio, number_id), transform)
        if sample_size is not None:
            # Random subset without replacement; plain ints keep indexing
            # independent of tensor-index support in the dataset.
            indices = torch.randperm(len(dataset))[:sample_size].tolist()
            dataset = torch.utils.data.Subset(dataset, indices)
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           shuffle=shuffle, num_workers=num_workers)

    # Built in the same order as before (train, valid, test) so the random
    # subsampling consumes the RNG in the same sequence.
    trainloader = _build_loader("Train", train_transform, train_sample_size, True)
    validloader = _build_loader("Valid", test_transform, valid_sample_size, False)
    testloader = _build_loader("Test", test_transform, test_sample_size, False)

    return trainloader, validloader, testloader

In [None]:
import gc
# Run a full garbage-collection pass; the value displayed below the cell is
# what gc.collect() returned (the number of unreachable objects it found).
gc.collect()

3528

In [None]:
# Per-identity fractions for each split; collect_images reads the "Valid" and
# "Test" entries and gives the remainder to the training split.
split_ratio = dict(Train=0.7, Valid=0.2, Test=0.1)

# Build the three loaders from the first 10 identity folders, 4 samples per batch.
trainloader, valloader, testloader = prepare_data(
    split_ratio=split_ratio,
    batch_size=4,
    number_id=10,
)

100%|██████████| 10/10 [00:00<00:00, 10.33it/s]
100%|██████████| 10/10 [00:00<00:00, 10.11it/s]
100%|██████████| 10/10 [00:00<00:00, 10.17it/s]


In [13]:
from models.vit import VisionTransformer
import json
from models.block import ParallelScalingBlock
from models.norm import RmsNorm


# Where the model hyper-parameters are persisted as JSON.
configfile = "configs/test_config.json"

# Hyper-parameters for the Vision Transformer. Flags are stored as 0/1 so the
# JSON file holds plain integers.
config = dict(
    patch_size=14,
    embed_dim=48 * 10,
    depth=48,
    num_heads=48,
    pre_norm=1,
    no_embed_class=1,
    qkv_bias=0,
    qk_norm=1,
)

# Write the configuration to disk ...
with open(configfile, 'w') as f:
    json.dump(config, f, indent=4)

# ... and read it straight back, so the model below is built from the file
# contents rather than from the in-memory dict.
with open(configfile, 'r') as f:
    config = json.load(f)

# Keyword arguments for the model: the JSON-backed values plus the two
# components (norm layer, block class) that cannot be stored in JSON.
model_own = {
    "patch_size": config["patch_size"],
    "embed_dim": config["embed_dim"],
    "depth": config["depth"],
    "num_heads": config["num_heads"],
    "pre_norm": config["pre_norm"],
    "no_embed_class": config["no_embed_class"],
    "norm_layer": RmsNorm,
    "block_fn": ParallelScalingBlock,
    "qkv_bias": config["qkv_bias"],
    "qk_norm": config["qk_norm"],
}

vit = VisionTransformer(**model_own)

In [14]:
# Total number of scalar entries across every parameter tensor of the model.
total_params = sum(weight.numel() for weight in vit.parameters())
total_params

133760200

In [4]:
# Scratch arithmetic. 48 matches the num_heads/depth values in the config cell;
# the origin of 6144 is not shown in this notebook — TODO confirm what it represents.
6144/48

128.0

In [26]:
import gc
# Run a full garbage-collection pass; the value displayed below the cell is
# what gc.collect() returned (the number of unreachable objects it found).
gc.collect()

0

In [1]:
# memory before storing model
!free -h

              total        used        free      shared  buff/cache   available
Mem:          251Gi        64Gi       171Gi       328Mi        16Gi       185Gi
Swap:          11Gi       2.9Gi       9.1Gi


In [3]:
# memory after storing model
!free -h

              total        used        free      shared  buff/cache   available
Mem:          251Gi        67Gi       168Gi       328Mi        16Gi       182Gi
Swap:          11Gi       2.9Gi       9.1Gi


In [25]:
!nvidia-smi

Tue May 23 21:20:57 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.106.00   Driver Version: 460.106.00   CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A40                 On   | 00000000:81:00.0 Off |                    0 |
|  0%   61C    P0    93W / 300W |  25379MiB / 45634MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [21]:
!nvidia-smi

Tue May 23 21:12:34 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.106.00   Driver Version: 460.106.00   CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A40                 On   | 00000000:81:00.0 Off |                    0 |
|  0%   77C    P0   115W / 300W |  25379MiB / 45634MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces