In [36]:
import argparse
import math
import os
import random
from collections import OrderedDict
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import albumentations as A
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import cv2
import timm
from ultralytics import YOLO
from PIL import Image

from gan import Generator
from gan import Discriminator
from gan import weights_init
from dataset import HandDataset

In [2]:
# Paths and global settings shared by the cells below.
ds_path = 'archive/Hands/Hands'  # folder with the raw hand images
img_size = 256  # side length the albumentations pipeline resizes images to
random_state = 69
# `is_available` must be *called*: the bare function object is always truthy,
# which used to select 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): not referenced by the visible cells; the latent size actually
# used for the generator is `nz` from the hyper-parameter cell.
z_size = 512

# Per-image metadata: subject id, age, gender, skinColor, accessories,
# nailPolish, aspectOfHand, imageName, irregularities.
df = pd.read_csv('archive/HandInfo.csv')
df

Unnamed: 0,id,age,gender,skinColor,accessories,nailPolish,aspectOfHand,imageName,irregularities
0,0,27,male,fair,0,0,dorsal right,Hand_0000002.jpg,0
1,0,27,male,fair,0,0,dorsal right,Hand_0000003.jpg,0
2,0,27,male,fair,0,0,dorsal right,Hand_0000004.jpg,0
3,0,27,male,fair,0,0,dorsal right,Hand_0000005.jpg,0
4,0,27,male,fair,0,0,dorsal right,Hand_0000006.jpg,0
...,...,...,...,...,...,...,...,...,...
11071,1589,22,female,fair,0,0,palmar left,Hand_0011740.jpg,0
11072,1589,22,female,fair,0,0,palmar left,Hand_0011741.jpg,0
11073,1589,22,female,fair,0,0,palmar left,Hand_0011742.jpg,0
11074,1589,22,female,fair,0,0,palmar left,Hand_0011743.jpg,0


In [3]:
# --- Training hyper-parameters -------------------------------------------
# Data loading
batch_size = 32
image_size = 256  # NOTE(review): duplicates `img_size` from the first config cell
workers = 4       # intended number of DataLoader worker processes

# Network shapes
nc = 3    # channels per image (the dataset yields 3 x 256 x 256 tensors)
nz = 500  # latent vector length; noise is drawn as (batch, nz, 1, 1)
ngf = 64  # not used in this notebook; presumably the generator width in gan.py
ndf = 64  # not used in this notebook; presumably the discriminator width in gan.py
ngpu = 1  # passed to the Generator/Discriminator constructors

# Optimisation
num_epochs = 300
lr = 2e-4
beta1 = 0.5  # Adam beta1

In [None]:
# Fine-tune a YOLO11-nano detector on the hand/palm detection dataset, then run
# it on the first few raw images as a quick sanity check.
# NOTE(review): the data.yaml path is absolute and machine-specific.
yolo = YOLO('yolo11n.pt').to(device)

res = yolo.train(
    data='/home/anton/PycharmProjects/another_projects/bio/hand-palm-detection.v2i.yolov11/data.yaml',
    epochs=200,
)

results = yolo([os.path.join(ds_path, name) for name in os.listdir(ds_path)[:3]])
results

In [None]:
# Evaluate the fine-tuned detector (presumably on the validation split declared
# in the data.yaml used for training — confirm in the ultralytics run output).
yolo.val()

In [None]:
# Print the detected bounding boxes for each sample image predicted above.
for result in results:
    boxes = result.boxes
    print(boxes)

In [4]:
# Crop every raw hand image around the first box returned by the fine-tuned
# YOLO detector, padded by a margin, and write it to `cropped_images/`.
# Images without a detection are copied through unchanged.

# NOTE(review): absolute, machine-specific checkpoint path.
path_to_model = '/home/anton/PycharmProjects/another_projects/bio/best.pt'
best_yolo = YOLO(path_to_model)

crop_dir = 'cropped_images'
crop_margin = 0.2  # fraction of image width/height added on each side of the box

os.makedirs(crop_dir, exist_ok=True)


def _crop_with_margin(img, box, margin=crop_margin):
    """Return ``img`` cropped to ``box`` = (x1, y1, x2, y2), grown by ``margin``.

    The margin is a fraction of the full image width (for x) and height (for y).
    The grown box is clamped to the last valid pixel index, and the slice
    includes column ``x2`` and row ``y2``. The previous clamp to ``size - 2``
    always dropped the last row/column.
    """
    height, width = img.shape[:2]
    x1, y1, x2, y2 = box
    x1 = max(int(x1 - margin * width), 0)
    x2 = min(int(x2 + margin * width), width - 1)
    y1 = max(int(y1 - margin * height), 0)
    y2 = min(int(y2 + margin * height), height - 1)
    return img[y1:y2 + 1, x1:x2 + 1]


all_fns = [os.path.join(ds_path, fn) for fn in os.listdir(ds_path)]

for fn in tqdm(all_fns):
    img = cv2.imread(fn)
    if img is None:
        # cv2.imread returns None for unreadable / non-image files.
        tqdm.write(f'skipping unreadable file: {fn}')
        continue

    result = best_yolo.predict(fn, show=False, verbose=False)
    boxes = result[0].boxes.xyxy
    if len(boxes) > 0:
        # Use the first reported box.
        x1, y1, x2, y2 = (int(v) for v in boxes[0].detach().cpu().numpy())
        img = _crop_with_margin(img, (x1, y1, x2, y2))

    cv2.imwrite(os.path.join(crop_dir, os.path.basename(fn)), img)

100%|█████████████████████████████████████| 11076/11076 [04:59<00:00, 37.03it/s]


In [5]:
def seed(random_state):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    CUDA). Previously only PyTorch was seeded.

    ``cudnn.benchmark = True`` lets cuDNN choose kernels per input shape. This
    trades bit-exact reproducibility for speed, and is kept as in the original
    training runs.

    Args:
        random_state: integer seed applied to all generators.
    """
    random.seed(random_state)
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    torch.cuda.manual_seed(random_state)
    torch.backends.cudnn.benchmark = True
    
# Seed the random number generators once, up front, with the notebook-wide `random_state`.
seed(random_state)

In [43]:
def save_images(timages, folder):
    """Write each image of a batch to ``<folder>/<index>.png``.

    Args:
        timages: tensor indexed as (N, C, H, W). Each image is permuted to
            (H, W, C) before writing. Values are assumed to lie in [0, 1] and
            are scaled by 255 — TODO confirm against the generator's output range.
        folder: output directory, created (with parents) if missing.

    Pixels are rounded and saturated to uint8 explicitly, rather than handing a
    float array to ``cv2.imwrite``. Channels are written as-is, so OpenCV
    interprets them in BGR order.
    """
    os.makedirs(folder, exist_ok=True)

    for nimage, timage in enumerate(timages):
        image = timage.permute(1, 2, 0).detach().cpu().numpy()
        image_u8 = np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)
        cv2.imwrite(os.path.join(folder, str(nimage) + '.png'), image_u8)
        
        
def save_image(image, image_path, size=(512, 512)):
    """Save one float image to ``image_path``, resized to ``size``.

    Args:
        image: array of shape (H, W, C). Values are assumed in [0, 1] and are
            scaled by 255 — TODO confirm against the generator's output range.
            The caller's array is no longer modified in place; the old
            ``image *= 255`` also rescaled the tensor it views on CPU.
        image_path: destination file; its extension selects the encoder.
        size: (width, height) passed to ``PIL.Image.resize``. Defaults to the
            previously hard-coded 512x512.

    The resize now happens in memory instead of writing, re-reading and
    re-writing the file. Resizing acts per channel, so the channel order
    written matches the old write → RGB read → RGB2BGR → write round trip.
    """
    image_u8 = np.clip(np.rint(np.asarray(image) * 255), 0, 255).astype(np.uint8)
    resized = Image.fromarray(image_u8).resize(size)
    cv2.imwrite(image_path, np.array(resized))
    
def discriminator_loss_real(y_pred):
    """Mean log-probability assigned to the "real" class (column 1).

    Args:
        y_pred: tensor of shape (batch, num_classes) holding class
            probabilities, with column 1 = real. Assumed already softmax-ed —
            TODO confirm at the call sites.

    Returns:
        0-d tensor ``mean(log(y_pred[:, 1] + 1e-6))``. This is a
        log-likelihood (<= ~0), i.e. a quantity to maximise; negate it before
        minimising it with an optimiser. An empty batch yields NaN.

    Computed with tensor ops so gradients flow back into ``y_pred``. The
    previous version rebuilt the value from ``.item()`` floats into a fresh
    leaf tensor, which silently detached it from the model.
    """
    return torch.log(y_pred[:, 1] + 1e-6).mean()

def discriminator_loss_fake(y_pred):
    """Mean log-probability of *not* being the "real" class (column 1).

    Args:
        y_pred: tensor of shape (batch, num_classes) holding class
            probabilities, with column 1 = real. Assumed already softmax-ed —
            TODO confirm at the call sites.

    Returns:
        0-d tensor ``mean(log(1 - y_pred[:, 1] + 1e-6))``. This is a
        log-likelihood (<= ~0) to maximise; negate it before minimising. An
        empty batch yields NaN.

    Computed with tensor ops so gradients flow back into ``y_pred``. The
    previous ``.item()``-based version returned a fresh leaf tensor with no
    connection to the model.
    """
    return torch.log(1 - y_pred[:, 1] + 1e-6).mean()


In [12]:
# Pre-processing for the hand dataset: only a resize to img_size x img_size.
# Disabled augmentation / normalisation options, kept for reference:
#   A.VerticalFlip(p=0.6), A.HorizontalFlip(p=0.6),
#   A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.6),
#   A.Normalize(mean=0.5, std=0.5)
# NOTE(review): this rebinds `transforms`, shadowing the `torchvision.transforms`
# module alias imported at the top of the notebook.
transforms = A.Compose(
    transforms=[
        A.Resize(height=img_size, width=img_size),
    ]
)

In [13]:
# Dataset over the YOLO-cropped images written by the cropping cell; `df` is the
# HandInfo metadata (presumably used by HandDataset to look up file names — confirm).
ds_path_ = 'cropped_images'
dataset = HandDataset(ds_path_, df, transform=transforms)
# Sanity check on one sample (recorded output: torch.Size([3, 256, 256])).
dataset[0].shape

torch.Size([3, 256, 256])

In [14]:
# Shuffled mini-batches for GAN training. `workers` from the hyper-parameter cell
# was defined but never passed; use it so images are loaded in worker processes.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=False,
)

In [15]:
# Build the generator/discriminator, optionally resume from checkpoints, and
# set up the optimisers and loss.
lr = 0.0002  # NOTE(review): same value as the hyper-parameter cell; redefined here

load = True  # resume from generator.pt / discriminator.pt in the working directory

generator = Generator(ngpu).to(device)
# discriminator = timm.create_model('resnet18', pretrained=False, num_classes=2).to(device)
discriminator = Discriminator(ngpu).to(device)
# NOTE(review): DCGAN-style initialisers are usually applied to every submodule
# via `model.apply(weights_init)`. Check that gan.weights_init handles being given
# the whole model; otherwise these two calls may be no-ops.
weights_init(generator)
weights_init(discriminator)

if load:
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine too.
    generator.load_state_dict(torch.load('generator.pt', map_location=device))
    discriminator.load_state_dict(torch.load('discriminator.pt', map_location=device))

optimizer_gen = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
# The discriminator is trained at half the generator's learning rate.
optimizer_dis = torch.optim.Adam(discriminator.parameters(), lr=lr / 2, betas=(beta1, 0.999))

# Discriminator outputs are used as class logits with integer targets: 1 = real, 0 = fake.
criterion = nn.CrossEntropyLoss()

# Fixed latent batch used to render comparable progress samples during training.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

real_label = 1
fake_label = 0

In [16]:
# Adversarial training loop: one discriminator step and one generator step per
# batch. Progress samples and checkpoints are written after every epoch.
os.makedirs('gan_gen', exist_ok=True)
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
for epoch in range(num_epochs):
    print(f'epoch: {epoch + 1}')
    with tqdm(dataloader, total=len(dataloader), position=0, leave=True) as pbar:
        for i, data in enumerate(pbar):
            real = data.to(device)
            b_size = real.size(0)

            # (1) Discriminator step: cross-entropy towards class `real_label`
            # for real images and `fake_label` for generated ones.
            discriminator.zero_grad()
            label = torch.full((b_size,), real_label, dtype=torch.long, device=device)
            loss_d_real = criterion(discriminator(real), label)
            loss_d_real.backward()

            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = generator(noise)
            label.fill_(fake_label)
            # detach(): the discriminator update must not push gradients into the generator.
            loss_d_fake = criterion(discriminator(fake.detach()), label)
            loss_d_fake.backward()
            loss_d = loss_d_real + loss_d_fake  # for logging only
            optimizer_dis.step()

            # (2) Generator step: non-saturating loss, i.e. label fakes as real.
            generator.zero_grad()
            label.fill_(real_label)
            loss_g = criterion(discriminator(fake), label)
            loss_g.backward()
            optimizer_gen.step()

            d_value, g_value = loss_d.item(), loss_g.item()
            pbar.set_postfix(OrderedDict(loss_d=d_value, loss_g=g_value))
            G_losses.append(g_value)
            D_losses.append(d_value)
            iters += 1  # optimisation steps (previously incremented once per epoch)

    # Progress samples from the fixed latent batch, once per epoch.
    # NOTE(review): the generator is still in train mode here, so its BatchNorm
    # layers use, and update running stats from, this batch; consider eval()/train().
    with torch.no_grad():
        fake = generator(fixed_noise).detach().cpu()
    save_images(fake, 'gan_gen/sample_' + str(epoch))
    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

    torch.save(generator.state_dict(), 'generator.pt')
    torch.save(discriminator.state_dict(), 'discriminator.pt')


Starting Training Loop...
epoch: 1


100%|██████████████| 347/347 [02:49<00:00,  2.05it/s, loss_d=0.00117, loss_g=16]


epoch: 2


100%|████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=9.09e-6, loss_g=24.7]


epoch: 3


100%|███████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=0.000106, loss_g=15.7]


epoch: 4


100%|██████████████| 347/347 [02:57<00:00,  1.95it/s, loss_d=0.0124, loss_g=9.2]


epoch: 5


100%|█████████████| 347/347 [02:58<00:00,  1.95it/s, loss_d=0.0425, loss_g=9.61]


epoch: 6


100%|██████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=0.119, loss_g=11.9]


epoch: 7


100%|████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=0.00628, loss_g=14.4]


epoch: 8


100%|██████████████| 347/347 [02:58<00:00,  1.95it/s, loss_d=0.322, loss_g=10.4]


epoch: 9


100%|████████████| 347/347 [02:58<00:00,  1.95it/s, loss_d=0.00693, loss_g=11.5]


epoch: 10


100%|███████████| 347/347 [02:58<00:00,  1.95it/s, loss_d=0.000144, loss_g=14.8]


epoch: 11


100%|██████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=0.605, loss_g=1.69]


epoch: 12


100%|███████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=0.000299, loss_g=12.9]


epoch: 13


100%|████████████████| 347/347 [02:57<00:00,  1.95it/s, loss_d=0.046, loss_g=12]


epoch: 14


100%|████████████| 347/347 [02:57<00:00,  1.95it/s, loss_d=0.00164, loss_g=10.2]


epoch: 15


100%|██████████████| 347/347 [02:58<00:00,  1.95it/s, loss_d=0.00754, loss_g=13]


epoch: 16


100%|██████████████| 347/347 [03:01<00:00,  1.91it/s, loss_d=0.886, loss_g=4.11]


epoch: 17


100%|███████████████| 347/347 [03:14<00:00,  1.79it/s, loss_d=1.81, loss_g=1.98]


epoch: 18


100%|████████████| 347/347 [03:13<00:00,  1.79it/s, loss_d=0.00175, loss_g=11.3]


epoch: 19


100%|█████████████| 347/347 [03:12<00:00,  1.81it/s, loss_d=3.67, loss_g=0.0372]


epoch: 20


100%|███████████| 347/347 [03:11<00:00,  1.81it/s, loss_d=3.86, loss_g=0.000255]


epoch: 21


100%|████████████████| 347/347 [03:11<00:00,  1.81it/s, loss_d=0.104, loss_g=16]


epoch: 22


100%|████████████| 347/347 [03:12<00:00,  1.80it/s, loss_d=6.26, loss_g=1.31e-5]


epoch: 23


100%|█████████████| 347/347 [03:12<00:00,  1.81it/s, loss_d=0.000111, loss_g=15]


epoch: 24


100%|█████████████| 347/347 [03:10<00:00,  1.82it/s, loss_d=0.0193, loss_g=16.1]


epoch: 25


100%|██████████████| 347/347 [03:11<00:00,  1.81it/s, loss_d=0.501, loss_g=6.03]


epoch: 26


100%|█████████████| 347/347 [03:10<00:00,  1.82it/s, loss_d=0.0014, loss_g=11.5]


epoch: 27


100%|███████████| 347/347 [03:12<00:00,  1.80it/s, loss_d=0.000227, loss_g=15.5]


epoch: 28


100%|███████████| 347/347 [03:10<00:00,  1.82it/s, loss_d=0.000192, loss_g=10.5]


epoch: 29


100%|███████████████| 347/347 [03:11<00:00,  1.81it/s, loss_d=11.8, loss_g=5.13]


epoch: 30


100%|████████████| 347/347 [03:11<00:00,  1.81it/s, loss_d=0.00188, loss_g=9.31]


epoch: 31


100%|████████████| 347/347 [03:12<00:00,  1.80it/s, loss_d=0.00126, loss_g=13.9]


epoch: 32


100%|██████████████| 347/347 [03:00<00:00,  1.92it/s, loss_d=0.025, loss_g=12.5]


epoch: 33


100%|███████████████| 347/347 [03:03<00:00,  1.89it/s, loss_d=1.15, loss_g=3.27]


epoch: 34


100%|█████████████| 347/347 [02:59<00:00,  1.94it/s, loss_d=0.0781, loss_g=7.09]


epoch: 35


100%|███████████████| 347/347 [02:59<00:00,  1.93it/s, loss_d=0.0488, loss_g=16]


epoch: 36


100%|████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=3.07e-6, loss_g=16.1]


epoch: 37


100%|█████████████| 347/347 [03:09<00:00,  1.83it/s, loss_d=0.0226, loss_g=11.2]


epoch: 38


100%|████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=0.00639, loss_g=14.1]


epoch: 39


100%|██████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=3.2, loss_g=0.0681]


epoch: 40


100%|██████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=0.241, loss_g=11.7]


epoch: 41


100%|██████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=6.34, loss_g=0.943]


epoch: 42


100%|████████████| 347/347 [02:57<00:00,  1.95it/s, loss_d=2.35e-5, loss_g=12.3]


epoch: 43


100%|████████████| 347/347 [02:57<00:00,  1.95it/s, loss_d=0.00443, loss_g=13.5]


epoch: 44


100%|██████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=0.136, loss_g=13.2]


epoch: 45


100%|████████████| 347/347 [02:58<00:00,  1.95it/s, loss_d=0.00023, loss_g=14.9]


epoch: 46


100%|█████████████| 347/347 [02:57<00:00,  1.95it/s, loss_d=0.0986, loss_g=11.4]


epoch: 47


100%|████████████| 347/347 [02:57<00:00,  1.95it/s, loss_d=8.34e-7, loss_g=14.3]


epoch: 48


100%|█████████████| 347/347 [02:57<00:00,  1.96it/s, loss_d=0.0046, loss_g=8.96]


epoch: 49


100%|██████████████| 347/347 [02:57<00:00,  1.95it/s, loss_d=0.501, loss_g=4.29]


epoch: 50


100%|████████████| 347/347 [02:57<00:00,  1.95it/s, loss_d=4.08e-6, loss_g=17.1]


epoch: 51


100%|██████████████| 347/347 [02:58<00:00,  1.95it/s, loss_d=0.00338, loss_g=17]


epoch: 52


100%|████████████| 347/347 [02:57<00:00,  1.96it/s, loss_d=3.76e-5, loss_g=12.9]


epoch: 53


100%|████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=5.59e-5, loss_g=12.1]


epoch: 54


100%|████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=2.98e-8, loss_g=19.4]


epoch: 55


100%|███████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=1.77, loss_g=40.9]


epoch: 56


100%|█████████████| 347/347 [02:57<00:00,  1.96it/s, loss_d=0.0167, loss_g=9.97]


epoch: 57


100%|████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=8.05e-7, loss_g=17.1]


epoch: 58


100%|█████████████| 347/347 [02:58<00:00,  1.94it/s, loss_d=4.2e-6, loss_g=14.6]


epoch: 59


100%|███████████████| 347/347 [02:58<00:00,  1.95it/s, loss_d=0.23, loss_g=3.78]


epoch: 60


100%|████████████| 347/347 [02:59<00:00,  1.93it/s, loss_d=1.04e-6, loss_g=19.9]


epoch: 61


100%|████████████████| 347/347 [03:00<00:00,  1.92it/s, loss_d=3.22, loss_g=0.6]


epoch: 62


100%|████████████| 347/347 [02:57<00:00,  1.96it/s, loss_d=8.28e-6, loss_g=12.3]


epoch: 63


100%|██████████████| 347/347 [02:57<00:00,  1.96it/s, loss_d=2.21, loss_g=0.013]


epoch: 64


100%|███████████████| 347/347 [02:55<00:00,  1.97it/s, loss_d=4.57, loss_g=0.92]


epoch: 65


100%|███████████████| 347/347 [02:55<00:00,  1.97it/s, loss_d=3.76, loss_g=1.24]


epoch: 66


100%|██████████████| 347/347 [02:56<00:00,  1.97it/s, loss_d=1.49e-7, loss_g=16]


epoch: 67


100%|███████████████| 347/347 [02:56<00:00,  1.96it/s, loss_d=1.16, loss_g=2.94]


epoch: 68


100%|███████████████| 347/347 [02:56<00:00,  1.97it/s, loss_d=1.15, loss_g=18.4]


epoch: 69


100%|████████████████| 347/347 [02:57<00:00,  1.96it/s, loss_d=0.814, loss_g=14]


epoch: 70


100%|█████████████| 347/347 [02:56<00:00,  1.97it/s, loss_d=0.0832, loss_g=16.9]


epoch: 71


 48%|█████▎     | 168/347 [01:26<01:31,  1.95it/s, loss_d=0.000188, loss_g=13.3]


KeyboardInterrupt: 

In [22]:
# generator = Generator(ngpu).to(device)
# generator.load_state_dict(torch.load('generator.pt'))
# generator.eval()
# model_for_sampling = timm.create_model('resnet18', pretrained=True, num_classes=2).to(device)
# optimizer = torch.optim.Adam(model_for_sampling.parameters(), lr=lr)

# for epoch in range(5):
#     print(f'epoch: {epoch + 1}')
#     with tqdm(dataloader, total=len(dataloader), position=0, leave=True) as pbar:
#         for i, data in enumerate(pbar):
#             model_for_sampling.zero_grad()

#             real_cpu = data.to(device)

#             b_size = real_cpu.size(0)
#             label = torch.LongTensor([1 for j in range(b_size)]).to(device)

#             output = model_for_sampling(real_cpu)

#             loss_d_real = criterion(output, label)

#             loss_d_real.backward()
#             D_x = output.mean().item()
#             noise = torch.randn(b_size, nz, 1, 1, device=device)
#             fake = generator(noise)
#             label.fill_(fake_label)
#             output = model_for_sampling(fake.detach())

#             loss_d_fake = criterion(output, label)
            
#             loss_d_fake.backward()
#             D_G_z1 = output.mean().item()
#             loss_d = loss_d_real + loss_d_fake

#             optimizer.step()

#             pbar.set_postfix(
#                 OrderedDict(loss_d=loss_d.item())
#             )
        
#         iters += 1
#     torch.save(model_for_sampling.state_dict(), 'sep.pt')

epoch: 1


100%|██████████████████████████| 347/347 [02:23<00:00,  2.42it/s, loss_d=0.0156]


epoch: 2


100%|█████████████████████████| 347/347 [02:28<00:00,  2.33it/s, loss_d=0.00407]


epoch: 3


100%|████████████████████████| 347/347 [02:29<00:00,  2.32it/s, loss_d=0.000308]


epoch: 4


100%|████████████████████████| 347/347 [02:30<00:00,  2.30it/s, loss_d=0.000457]


epoch: 5


100%|████████████████████████| 347/347 [02:30<00:00,  2.31it/s, loss_d=0.000145]


In [45]:
# Reload the trained networks from their checkpoints for sampling.
# map_location lets checkpoints saved on a GPU load on a CPU-only machine too.
generator = Generator(ngpu).to(device)
generator.load_state_dict(torch.load('generator.pt', map_location=device))

discriminator = Discriminator(ngpu).to(device)
discriminator.load_state_dict(torch.load('discriminator.pt', map_location=device))
# Alternative scorer: the separately trained ResNet-18 from the commented-out cell above.
# discriminator = timm.create_model('resnet18', num_classes=2).to(device)
# discriminator.load_state_dict(torch.load('sep.pt'))

# Inference mode: BatchNorm layers use their running statistics.
discriminator.eval()
generator.eval()

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(500, 2048, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(2048, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Ba

In [50]:
# Draw `steps` batches of `bs` generated images. Keep those the discriminator
# scores as "real" (class 1) with probability >= keep_threshold, and save them
# as generated/<n>.png.
bs = 32
steps = 1000
keep_threshold = 0.6

os.makedirs('generated', exist_ok=True)

cnt = 0
for step in tqdm(range(steps)):
    with torch.no_grad():
        noise = torch.randn(bs, nz, 1, 1, device=device)
        images = generator(noise)
        probs = discriminator(images).softmax(dim=1)

        # Select on the device and copy only accepted images, once per batch,
        # instead of one host sync per image. Boolean masking keeps batch
        # order, so file numbering is unchanged.
        keep = probs[:, 1] >= keep_threshold
        kept_images = images[keep].permute(0, 2, 3, 1).cpu().numpy()

    for image in kept_images:
        save_image(image, os.path.join('generated', str(cnt) + '.png'))
        cnt += 1

100%|███████████████████████████████████████| 1000/1000 [00:46<00:00, 21.72it/s]
