# Загрузка данных

In [None]:
!pip install pytorch-lightning
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.utils.data import DataLoader, random_split
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# Global seed and loader settings shared by the cells below.
random_seed = 42
torch.manual_seed(random_seed)

BATCH_SIZE = 24
# Number of GPUs to use, capped at one.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# os.cpu_count() may return None when the count cannot be determined; fall
# back to a single core (-> 0 workers, i.e. load in the main process).
NUM_WORKERS = (os.cpu_count() or 1) // 2

In [None]:
from google.colab import drive

# Make Google Drive available under /content/gdrive (re-mount if already mounted);
# the data files loaded below live there.
drive.mount(mountpoint='/content/gdrive', force_remount=True)

In [None]:
# Raw image stack from Drive.
images = np.load(file="/content/gdrive/MyDrive/small_data.npy")

In [None]:
# Targets from CSV; the first row is dropped (presumably a header row).
labels = np.genfromtxt('/content/gdrive/MyDrive/all_target.csv', delimiter=',')[1:]
# Binarize: 1 - healthy (raw target < 3) | 0 - problems.
labels = np.array([int(target < 3) for target in labels])
# Expand to 810 per-slice labels: slice k takes the label of entry k // 10
# (presumably 81 subjects x 10 slices each).
label = labels[np.arange(81 * 10) // 10].astype(int)

NameError: ignored

In [None]:
# Show the stored layout before flattening (a bare ``images.shape`` that is
# not the last expression of a cell is never displayed).
print(images.shape)
# Flatten into 810 slices of 512x512. ``reshape`` raises if the element count
# does not match, whereas ``np.resize`` would silently repeat or truncate the
# pixel data and produce corrupted slices.
images = images.reshape(810, 512, 512)

In [None]:
# Slices 0..699 are used for training, the remainder (700 onward) for evaluation.
train_images, test_images = images[:700, :, :], images[700:, :, :]
train_labels, test_labels = label[:700], label[700:]

In [None]:
import albumentations as A
import albumentations.pytorch


class SegmentationDataset(torch.utils.data.Dataset):
    """Training dataset: cropped 2-D slices with float32 labels.

    Args:
        transforms: callable invoked as ``transforms(image=...)`` that returns
            a dict with key ``"image"`` (albumentations ``Compose`` style).
        images: array whose ``images[idx]`` is a 2-D slice. Defaults to the
            notebook-level ``train_images``.
        labels: per-slice targets aligned with ``images``. Defaults to the
            notebook-level ``train_labels``.

    Raises:
        ValueError: if ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, transforms, images=None, labels=None):
        # Fall back to the globals so existing ``SegmentationDataset(transforms=...)``
        # calls keep working.
        self.images = train_images if images is None else images
        self.labels = train_labels if labels is None else labels
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"images ({len(self.images)}) and labels ({len(self.labels)}) "
                "must have the same length"
            )
        self.transforms = transforms

    def __getitem__(self, idx):
        # Fixed 300x300 crop: rows 130..429, columns 100..399.
        img = self.images[idx][130:430, 100:400]
        image = self.transforms(image=img)["image"]
        # float32 target: it is fed to F.binary_cross_entropy during training.
        return image, self.labels[idx].astype(np.float32)

    def __len__(self):
        return self.images.shape[0]

In [None]:
class SegmentationTestDataset(torch.utils.data.Dataset):
    """Evaluation dataset: cropped 2-D slices with integer labels.

    Args:
        transforms: callable invoked as ``transforms(image=...)`` that returns
            a dict with key ``"image"`` (albumentations ``Compose`` style).
        images: array whose ``images[idx]`` is a 2-D slice. Defaults to the
            notebook-level ``test_images``.
        labels: per-slice targets aligned with ``images``. Defaults to the
            notebook-level ``test_labels``.

    Raises:
        ValueError: if ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, transforms, images=None, labels=None):
        # Fall back to the globals so existing calls keep working.
        self.images = test_images if images is None else images
        self.labels = test_labels if labels is None else labels
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"images ({len(self.images)}) and labels ({len(self.labels)}) "
                "must have the same length"
            )
        self.transforms = transforms

    def __getitem__(self, idx):
        # Same fixed 300x300 crop as the training dataset.
        img = self.images[idx][130:430, 100:400]
        image = self.transforms(image=img)["image"]
        return image, self.labels[idx].astype(int)

    def __len__(self):
        return self.images.shape[0]

In [None]:
# Evaluation data: deterministic preprocessing (resize, normalize, to tensor).
# NOTE(review): A.Normalize divides by max_pixel_value=255 by default —
# confirm this matches the value range of the .npy slices.
test_data = SegmentationTestDataset(
    transforms=A.Compose([
        A.Resize(height=128, width=128),
        A.Normalize((0.5,), (0.5,)),
        A.pytorch.transforms.ToTensorV2(),
    ]),
)

# A single batch covers the whole test split, so the validation loop records
# exactly one accuracy value per epoch.
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=max(1, len(test_data)),
    # Pinned memory only helps host->GPU copies.
    pin_memory=torch.cuda.is_available(),
    num_workers=NUM_WORKERS,
    shuffle=False,
)

In [None]:
# Display the number of evaluation slices (notebook cell output).
len(test_data)

In [None]:
# Training data: same preprocessing as the test split (no random augmentation).
data = SegmentationDataset(
    transforms=A.Compose([
        A.Resize(height=128, width=128),
        A.Normalize((0.5,), (0.5,)),
        A.pytorch.transforms.ToTensorV2(),
    ]),
)

train_loader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=BATCH_SIZE,
    # Pinned memory only helps host->GPU copies.
    pin_memory=torch.cuda.is_available(),
    num_workers=NUM_WORKERS,
    shuffle=True,
)

In [None]:
def im_show(img_list) -> None:
    """Plot images stacked vertically in one figure with axes hidden.

    Args:
        img_list: sequence of images (arrays or CPU tensors); each is
            ``squeeze()``-d before plotting, so (1, H, W) inputs are accepted.
            An empty sequence draws nothing.
    """
    if len(img_list) == 0:
        return

    # squeeze=False keeps ``axes`` two-dimensional even for a single image,
    # so indexing works for any list length.
    fig, axes = plt.subplots(len(img_list), 1, figsize=(16, 16), squeeze=False)

    for ax, sample in zip(axes[:, 0], img_list):
        ax.imshow(sample.squeeze())
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    # Lay out after the content exists.
    fig.tight_layout()
    plt.show()


# Preview the first four preprocessed training slices.
img_list = [data[k][0] for k in range(4)]
im_show(img_list)

# Создание моделей (генератор и дискриминатор)

In [None]:
class Discriminator(nn.Module):
    """Convolutional network mapping an image batch to scores in (0, 1).

    Args:
        img_shape: ``(channels, height, width)`` of the inputs. Used to size
            the first convolution and the final linear layer; for
            ``(1, 128, 128)`` the architecture is unchanged from the
            previously hard-coded version (Linear(4*21*21, 1)).
    """

    def __init__(self, img_shape):
        super().__init__()
        channels, height, width = img_shape

        def conv_out(size, kernel, stride, pad=0):
            """Spatial output size of a Conv2d along one axis."""
            return (size + 2 * pad - kernel) // stride + 1

        # Two convolutions: (k=4, s=2, p=1) then (k=4, s=3, p=0).
        out_h = conv_out(conv_out(height, 4, 2, 1), 4, 3)
        out_w = conv_out(conv_out(width, 4, 2, 1), 4, 3)

        self.model = nn.Sequential(
            nn.Conv2d(channels, 48, 4, 2, 1),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(48, 4, 4, 3),
            nn.BatchNorm2d(4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(4 * out_h * out_w, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return scores of shape (batch, 1) for ``img`` of shape (batch, C, H, W)."""
        return self.model(img)

In [None]:
# Print the layer summary of a Discriminator constructed with img_shape=(1, 128, 128).
print(Discriminator((1, 128, 128)))

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent vectors to images.

    Args:
        latent_dim: length of each noise vector ``z``.
        img_shape: stored as ``self.img_shape``; the layer stack below always
            produces (1, 128, 128) images regardless of this value.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_ch, out_ch, kernel_size, stride, padding, normalize=True, relu=True):
            """ConvTranspose2d, optionally followed by BatchNorm2d and LeakyReLU(0.2)."""
            layers = [nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding)]
            if normalize:
                # BatchNorm2d's second positional argument is ``eps``, so the
                # former ``BatchNorm2d(out_ch, 0.8)`` set eps=0.8 (presumably
                # meant as a momentum value), which strongly damps the
                # normalization. Use PyTorch's defaults instead.
                layers.append(nn.BatchNorm2d(out_ch))
            if relu:
                layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Spatial size per stage: 1 -> 8 -> 21 -> 64 -> 128.
        self.model = nn.Sequential(
            *block(latent_dim, 128, 8, 2, 0),
            *block(128, 256, 4, 3, 2),
            *block(256, 48, 4, 3, 0),
            *block(48, 1, 4, 2, 1, False, False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape (batch, latent_dim) to images in [-1, 1], shape (batch, 1, 128, 128)."""
        inp = z.view(z.size(0), z.size(1), 1, 1)
        return self.model(inp)

In [None]:
# Print the layer summary of a Generator with latent_dim=100 and img_shape=(1, 128, 128).
print(Generator(100, (1, 128, 128)))

In [None]:
# Use the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Histories appended to during training and plotted afterwards.
discr_loss_trace, gen_loss_trace, test_acc = [], [], []

In [None]:
# Index groupings; their meaning is not documented in this notebook — TODO(review) confirm.
# (The previous ``S2`` line had a stray leading space -> IndentationError.)
S1 = [
    [1, 2, 0],
    [3, 4, 1],
    [5, 6, 2],
    [4, 6, 7],
    [9, 10, 8],
    [10, 11, 13],
    [12, 9, 14],
    [14, 13, 15],
]
S2 = [[0, 15], [12, 3], [11, 5], [7, 8]]

In [None]:
import time

In [None]:
class Fitter(object):
    """Adversarial trainer for a Generator / Discriminator pair.

    The discriminator is trained with binary cross-entropy to output the
    dataset label for real slices and 0 for generated ones; the generator is
    trained to make the discriminator output 1 for its samples.

    Args:
        generator: network mapping ``(batch, latent_dim)`` noise to images.
        discriminator: network mapping images to scores in (0, 1).
        batch_size: stored for reference only; batching is done by the
            DataLoader passed to :meth:`fit`.
        n_epochs: number of passes over the training loader.
        latent_dim: length of the noise vectors fed to the generator.
        lr: Adam learning rate for both networks.
        n_critic: discriminator updates per generator update.
    """

    def __init__(
        self,
        generator,
        discriminator,
        batch_size=32,
        n_epochs=10,
        latent_dim=1,
        lr=0.0001,
        n_critic=5,
    ):
        self.generator = generator
        self.discriminator = discriminator
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.latent_dim = latent_dim
        self.lr = lr
        self.n_critic = n_critic

        # Move the networks first so the optimizers are built over the
        # parameters on their final device.
        self.generator.to(DEVICE)
        self.discriminator.to(DEVICE)

        self.opt_gen = torch.optim.Adam(self.generator.parameters(), lr=self.lr)
        self.opt_disc = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr)

    def fit(self, train_dataloader, valid_dataloader=None):
        """Run the training loop.

        Every ``(n_critic + 1)``-th iteration updates the generator; the other
        iterations update the discriminator. After each epoch the
        discriminator is evaluated on ``valid_dataloader`` (defaults to the
        notebook-level ``test_loader``), and every 50 epochs both networks are
        saved with ``torch.save``. Losses and accuracies are appended to the
        module-level ``gen_loss_trace``, ``discr_loss_trace`` and ``test_acc``.
        Both networks are left in eval mode on return.
        """
        if valid_dataloader is None:
            valid_dataloader = test_loader

        self._set_training(True)
        self.loss_history = []  # kept for compatibility; not filled

        step = 0
        for epoch in tqdm(range(self.n_epochs)):
            start = time.time()
            for real_objects, real_labels in train_dataloader:
                real_objects = real_objects.to(DEVICE)
                real_labels = torch.unsqueeze(real_labels.to(DEVICE), 1)

                # NOTE(review): WGAN-style weight clipping combined with a
                # sigmoid + BCE objective is unusual (it also clips BatchNorm
                # weights) — confirm this is intended.
                self._clip_discriminator_weights()

                if step % (self.n_critic + 1) == 0:
                    gen_loss_trace.append(self._generator_step(real_objects.size(0)))
                else:
                    discr_loss_trace.append(self._discriminator_step(real_objects, real_labels))
                step += 1

            # Validation
            self._set_training(False)
            self._validate(valid_dataloader)

            if epoch % 50 == 0:
                torch.save(self.generator, "model_gen_" + str(epoch))
                torch.save(self.discriminator, "model_disc_" + str(epoch))
            self._set_training(True)
            print('Time for epoch {} is {} sec'.format(epoch, time.time() - start))

        # Turn off training
        self._set_training(False)

    def _set_training(self, mode):
        """Put both networks into train (``True``) or eval (``False``) mode."""
        self.generator.train(mode)
        self.discriminator.train(mode)

    def _sample_noise(self, n):
        """Draw ``n`` standard-normal latent vectors directly on ``DEVICE``."""
        return torch.randn(n, self.latent_dim, device=DEVICE)

    def _clip_discriminator_weights(self, limit=0.01):
        """Clamp every discriminator parameter into ``[-limit, limit]`` in place."""
        with torch.no_grad():
            for p in self.discriminator.parameters():
                p.clamp_(-limit, limit)

    def _generator_step(self, n):
        """One generator update on ``n`` fresh samples; returns the loss value."""
        z = self._sample_noise(n)
        self.opt_gen.zero_grad()
        valid = torch.ones(n, 1, device=DEVICE)
        gen_loss = F.binary_cross_entropy(self.discriminator(self.generator(z)), valid)
        gen_loss.backward()
        self.opt_gen.step()
        return gen_loss.item()

    def _discriminator_step(self, real_objects, real_labels):
        """One discriminator update: real slices -> their labels, fakes -> 0."""
        n = real_objects.size(0)
        z = self._sample_noise(n)
        self.opt_disc.zero_grad()
        real_loss = F.binary_cross_entropy(self.discriminator(real_objects), real_labels.float())
        # The generator is not updated here, so skip building its graph.
        with torch.no_grad():
            fake_objects = self.generator(z)
        fake = torch.zeros(n, 1, device=DEVICE)
        fake_loss = F.binary_cross_entropy(self.discriminator(fake_objects), fake)
        discr_loss = (real_loss + fake_loss) / 2
        discr_loss.backward()
        self.opt_disc.step()
        return discr_loss.item()

    @torch.no_grad()
    def _validate(self, valid_dataloader, n_fake=30):
        """Append per-batch accuracy on real slices plus ``n_fake`` generated ones to ``test_acc``.

        Scores >= 0.5 count as class 1; generated samples have target 0.
        """
        for valid_objects, valid_labels in valid_dataloader:
            valid_objects = valid_objects.to(DEVICE)
            real_targets = valid_labels.to(DEVICE).long().view(-1)

            z = self._sample_noise(n_fake)
            scores = torch.cat([
                self.discriminator(valid_objects),
                self.discriminator(self.generator(z)),
            ], 0).view(-1)
            y_true = torch.cat([
                real_targets,
                torch.zeros(n_fake, dtype=torch.long, device=DEVICE),
            ], 0)
            y_pred = (scores >= 0.5).long()
            # sklearn signature is accuracy_score(y_true, y_pred); it needs CPU arrays.
            test_acc.append(accuracy_score(y_true.cpu().numpy(), y_pred.cpu().numpy()))

In [None]:
# Model sizes: 100-dim noise, single-channel 128x128 images.
latent_dim = 100
data_shape = (1, 128, 128)

generator = Generator(latent_dim=latent_dim, img_shape=data_shape)
discriminator = Discriminator(img_shape=data_shape)

# 500 epochs, 3 discriminator updates for every generator update.
fitter = Fitter(
    generator=generator,
    discriminator=discriminator,
    batch_size=BATCH_SIZE,
    n_epochs=500,
    latent_dim=latent_dim,
    lr=0.0001,
    n_critic=3,
)
fitter.fit(train_dataloader=train_loader)

In [None]:
# Loss curves side by side. Each x-axis counts that network's own updates,
# not global iterations. (The former 1x3 grid left the third panel empty.)
plt.figure(figsize=(15, 5))

plt.subplot(1, 2, 1)
plt.xlabel("Iteration")
plt.ylabel("Generator loss")
plt.plot(gen_loss_trace)

plt.subplot(1, 2, 2)
plt.xlabel("Iteration")
plt.ylabel("Discriminator loss")
plt.plot(discr_loss_trace, color="orange")

plt.tight_layout()

In [None]:
# Validation accuracy history: one point per appended validation batch
# (one per epoch when the test loader yields a single batch).
plt.figure(figsize=(8, 8))
plt.xlabel("Epoch")
plt.ylabel("Test accuracy")
plt.plot(test_acc)

In [None]:
# Draw two samples from the trained generator. The noise is moved to the
# model's device and the output back to CPU, because ``.numpy()`` fails on
# CUDA tensors. The result goes into a new name instead of overwriting the
# loaded dataset array ``images``.
generator.eval()
with torch.no_grad():
    z = torch.normal(0, 1, (2, latent_dim))
    generated_images = generator(z.to(DEVICE)).cpu().numpy()
im_show(generated_images)