<a href="https://colab.research.google.com/github/Quillbolt/colabnotebook/blob/main/Untitled5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install vit-pytorch

In [None]:
!pip install linformer
!pip install byol-pytorch

In [None]:
!unzip /content/drive/MyDrive/face1.zip -d /content/

In [None]:
import os

In [None]:
# Report every directory under the extracted dataset that holds more than five
# files, together with its file count.
# The loop variable is named `dirpath` (not `dir`) so the builtin `dir()` is
# not shadowed in the notebook's global namespace after this cell runs.
for dirpath, _, filenames in os.walk('/content/face1'):
    if len(filenames) > 5:
        print(dirpath + " has crossed limit, " + "total files: " + str(len(filenames)))


In [None]:
import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer

# Smoke test: push one random high-resolution image through a Linformer-backed ViT.
# 2048 / 32 = 64 patches per side, so the sequence is 64 * 64 = 4096 patch
# tokens plus one cls token.
efficient_transformer = Linformer(dim=512, seq_len=4096 + 1, depth=12, heads=8, k=256)

v = ViT(
    dim=512,
    image_size=2048,
    patch_size=32,
    num_classes=1000,
    transformer=efficient_transformer,
)

img = torch.randn(1, 3, 2048, 2048)  # random stand-in for a high resolution picture
v(img)  # (1, 1000)

In [None]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT

In [None]:
# Print the installed PyTorch version string.
print(f"Torch: {torch.__version__}")

Torch: 1.7.0+cu101


In [None]:
# Training hyperparameters shared by the cells below.
batch_size = 64  # samples per DataLoader batch
epochs = 20      # number of passes of the training loop
lr = 3e-5        # learning rate handed to the Adam optimizer
gamma = 0.7      # decay factor handed to the StepLR scheduler
seed = 42        # seed applied to the random number generators below

# Seed every RNG this notebook imports so that augmentation, shuffling and
# weight initialisation start from a fixed state on each run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
def make_weights_for_balanced_classes(images, nclasses):
    """Return one sampling weight per image, inversely proportional to class frequency.

    Args:
        images: sized sequence of items whose element at index 1 is an integer
            class index in ``range(nclasses)`` (the notebook passes
            ``dataset.imgs``).
        nclasses: total number of classes.

    Returns:
        A list of floats aligned with ``images``. Each entry equals
        ``len(images) / (number of images in that image's class)``.
        An empty ``images`` yields ``[]``.

    A class that never occurs in ``images`` gets a per-class weight of 0.0
    instead of raising ``ZeroDivisionError``; that weight is never handed out
    because no image belongs to the class.
    """
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1

    total = float(len(images))
    # Guard the division: an unrepresented class has count 0.
    weight_per_class = [total / c if c else 0.0 for c in count]
    return [weight_per_class[item[1]] for item in images]

In [None]:
# Preprocessing/augmentation pipeline; Compose applies the steps in list order.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),      # resize every image to 224x224
    transforms.RandomResizedCrop(224),  # random crop, rescaled to 224x224
    transforms.RandomHorizontalFlip(),  # random left-right flip
    transforms.ToTensor(),              # convert the image to a tensor
])

In [None]:
# Display the dataset summary (number of datapoints, root, transform).
# NOTE(review): `dataset` is assigned by the ImageFolder cell further down in
# this notebook; on a fresh kernel that cell has to run before this one.
dataset 

Dataset ImageFolder
    Number of datapoints: 27854
    Root location: ./face1/
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=PIL.Image.BILINEAR)
               RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=PIL.Image.BILINEAR)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
           )

In [None]:
# Split sizes: 70% of the dataset (truncated to an integer) for training,
# whatever remains for testing, so the two always sum to len(dataset).
train_length = int(len(dataset) * 0.7)
test_length = len(dataset) - train_length

In [None]:
# Split the dataset into train/test subsets. A generator seeded with `seed` is
# passed so the same samples land in each subset on every run; without it the
# split membership changes between runs.
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    (train_length, test_length),
    generator=torch.Generator().manual_seed(seed),
)

In [None]:
# For each path, take the last '/'-separated component and keep the text
# before its first '.'.
# NOTE(review): `train_list` is not defined in any cell visible in this
# notebook, so this cell raises NameError on a fresh kernel — confirm where the
# path list should come from, or remove the cell.
labels = [path.split('/')[-1].split('.')[0] for path in train_list]

In [None]:
# Display the full `labels` list as the cell output.
labels

In [None]:
# Validation loader: serves the held-out split in its stored order.
# NOTE(review): `test_dataset` is a random_split subset of `dataset`, so it is
# read through the same transform pipeline as the training data — confirm that
# random augmentation on validation samples is intended.
valid_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Load the images under ./face1/ (path is relative to the working directory)
# and attach the training transform pipeline.
dataset = datasets.ImageFolder(root='./face1/', transform=train_transforms)

In [None]:
# Per-image sampling weights (rarer classes get larger weights), wrapped in a
# sampler that draws len(weights) indices per pass.
# NOTE(review): the weights cover the whole `dataset`, and `sampler` is not
# passed to any DataLoader in the visible cells — confirm how it is meant to be used.
weights = torch.DoubleTensor(
    make_weights_for_balanced_classes(dataset.imgs, len(dataset.classes))
)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
                                                                                

In [None]:
# Display the number of classes found in the dataset.
len(dataset.classes)

In [None]:
# Training loader: batches the training split and reshuffles it every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Print the total number of samples in `dataset` and the number of batches in
# `train_loader`.
print(len(dataset), len(train_loader))

27854 436


In [None]:
# Set CUDA_LAUNCH_BLOCKING=1 for this process (synchronous CUDA kernel
# launches; commonly used so CUDA errors surface at the failing call).
# NOTE(review): this is a debugging aid that slows GPU work and presumably has
# to be set before CUDA is initialised — confirm it is still wanted here.
os.environ['CUDA_LAUNCH_BLOCKING']="1"

In [None]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on a machine without a GPU instead of failing at the
# first `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [None]:
# Linformer encoder used by the classifier. With 224-pixel images and
# 32-pixel patches there are 7 x 7 = 49 patch tokens, plus one cls token.
efficient_transformer = Linformer(dim=128, seq_len=49 + 1, depth=12, heads=8, k=64)

In [None]:
# Vision Transformer classifier built on `efficient_transformer`, moved to the
# selected device.
# The number of output classes is taken from the dataset instead of a
# hard-coded 288, so the classifier head cannot fall out of sync with the
# labels the loss function receives.
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=len(dataset.classes),
    transformer=efficient_transformer,
    channels=3,
).to(device)

In [None]:
# Training objective and optimisation setup.
criterion = nn.CrossEntropyLoss()                                  # classification loss
optimizer = optim.Adam(params=model.parameters(), lr=lr)           # Adam over all model parameters
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=gamma)  # multiplies lr by gamma per scheduler step

In [None]:
# Train for `epochs` epochs. After each epoch, evaluate on `valid_loader` and
# print the epoch-averaged loss and accuracy for both splits.
#
# Metrics are the mean of the per-batch means (each batch contributes
# 1 / len(loader)), so a smaller final batch counts as much as a full one.
#
# NOTE(review): `scheduler` (StepLR) is created elsewhere in this notebook but
# is not stepped in this loop, so the learning rate stays at `lr`. Add
# `scheduler.step()` at the end of each epoch if the decay is intended.
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    # Put the model in training mode; the validation pass below switches it
    # to eval mode, so this must be reset at the start of every epoch.
    model.train()
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() converts the 0-dim tensors to Python floats so the running
        # totals do not hold on to autograd history or device tensors.
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    # Evaluation mode and no gradient tracking for the validation pass.
    model.eval()
    epoch_val_accuracy = 0.0
    epoch_val_loss = 0.0
    with torch.no_grad():
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / len(valid_loader)
            epoch_val_loss += val_loss.item() / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 1 - loss : 4.1770 - acc: 0.1100 - val_loss : 4.1321 - val_acc: 0.1128



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 2 - loss : 4.0982 - acc: 0.1172 - val_loss : 4.0707 - val_acc: 0.1193



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 3 - loss : 4.0470 - acc: 0.1207 - val_loss : 4.0287 - val_acc: 0.1264



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 4 - loss : 4.0013 - acc: 0.1289 - val_loss : 4.0009 - val_acc: 0.1341



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 5 - loss : 3.9671 - acc: 0.1344 - val_loss : 3.9628 - val_acc: 0.1368



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 6 - loss : 3.9260 - acc: 0.1398 - val_loss : 3.9242 - val_acc: 0.1406



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 7 - loss : 3.8921 - acc: 0.1440 - val_loss : 3.9155 - val_acc: 0.1482



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 8 - loss : 3.8614 - acc: 0.1509 - val_loss : 3.8625 - val_acc: 0.1538



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 9 - loss : 3.8315 - acc: 0.1548 - val_loss : 3.8356 - val_acc: 0.1627



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 10 - loss : 3.7858 - acc: 0.1616 - val_loss : 3.8253 - val_acc: 0.1637



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 11 - loss : 3.7542 - acc: 0.1695 - val_loss : 3.7822 - val_acc: 0.1726



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 12 - loss : 3.7280 - acc: 0.1731 - val_loss : 3.7592 - val_acc: 0.1766



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 13 - loss : 3.6986 - acc: 0.1797 - val_loss : 3.7216 - val_acc: 0.1778



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 14 - loss : 3.6678 - acc: 0.1857 - val_loss : 3.6988 - val_acc: 0.1920



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 15 - loss : 3.6336 - acc: 0.1901 - val_loss : 3.6750 - val_acc: 0.1935



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 16 - loss : 3.6158 - acc: 0.1941 - val_loss : 3.6672 - val_acc: 0.1950



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 17 - loss : 3.5758 - acc: 0.2023 - val_loss : 3.6389 - val_acc: 0.1961



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 18 - loss : 3.5696 - acc: 0.2041 - val_loss : 3.6275 - val_acc: 0.1973



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 19 - loss : 3.5287 - acc: 0.2105 - val_loss : 3.6000 - val_acc: 0.2054



HBox(children=(FloatProgress(value=0.0, max=305.0), HTML(value='')))


Epoch : 20 - loss : 3.5023 - acc: 0.2155 - val_loss : 3.5880 - val_acc: 0.2112



In [None]:
# Show the GPU model, driver/CUDA versions and current GPU memory usage.
!nvidia-smi

Mon Nov 30 09:42:27 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.38       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   59C    P0    29W /  70W |   1673MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces