In [1]:

import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count

import torch
from torch import nn, einsum
import torch.nn.functional as F

import torchvision
from torch.optim import Adam
from torchvision import transforms as T, utils
from torch.utils.data import DataLoader

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from tqdm.auto import tqdm
from ema_pytorch import EMA

from accelerate import Accelerator

import os

from denoising_diffusion_pytorch.fid_evaluation import FIDEvaluation

from classifier_free_guidance import Unet, GaussianDiffusion

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_data(image_size, folder, batch_size):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_size),  # args.image_size + 1/4 *args.image_size
        torchvision.transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        torchvision.transforms.ToTensor()
    ])
    dataset = torchvision.datasets.ImageFolder(folder, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    return dataloader

In [3]:
def cycle(dl):
    while True:
        for data in dl:
            yield data

In [4]:
accelerator = Accelerator(
    split_batches = True,
    mixed_precision = 'fp16'
)

In [5]:
dl = get_data(256, '/home/yoos-bii/Desktop/data_tct/val', 1)
dl

<torch.utils.data.dataloader.DataLoader at 0x7f91480ccaf0>

In [6]:
data = next(iter(dl))
data

[tensor([[[[0.9451, 0.9490, 0.9412,  ..., 0.6549, 0.6471, 0.6863],
           [0.9373, 0.9412, 0.9451,  ..., 0.6588, 0.6392, 0.6627],
           [0.9412, 0.9373, 0.9451,  ..., 0.7020, 0.6863, 0.6902],
           ...,
           [0.9412, 0.9412, 0.9451,  ..., 0.9529, 0.9608, 0.9647],
           [0.9451, 0.9451, 0.9412,  ..., 0.9569, 0.9490, 0.9451],
           [0.9412, 0.9451, 0.9490,  ..., 0.9569, 0.9490, 0.9451]],
 
          [[0.9216, 0.9255, 0.9176,  ..., 0.7961, 0.7804, 0.8157],
           [0.9137, 0.9176, 0.9216,  ..., 0.7922, 0.7686, 0.7882],
           [0.9176, 0.9137, 0.9216,  ..., 0.8314, 0.8118, 0.8118],
           ...,
           [0.9176, 0.9176, 0.9216,  ..., 0.9294, 0.9373, 0.9412],
           [0.9216, 0.9216, 0.9176,  ..., 0.9333, 0.9255, 0.9216],
           [0.9176, 0.9216, 0.9255,  ..., 0.9333, 0.9255, 0.9216]],
 
          [[0.9373, 0.9412, 0.9333,  ..., 0.8471, 0.8314, 0.8588],
           [0.9294, 0.9333, 0.9373,  ..., 0.8510, 0.8196, 0.8314],
           [0.9333, 0.92

In [7]:
dl = accelerator.prepare(dl)
dl = cycle(dl)

In [8]:
model = Unet(
        dim = 64,
        dim_mults = (1, 2, 4, 8),
        num_classes = 15,
        cond_drop_prob = 0.5
    ).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 256,
    timesteps = 1000,
    sampling_timesteps=250,
    loss_type='l2'
).cuda()

In [9]:
opt = Adam(diffusion.parameters(), lr = 1e-4, betas = (0.9, 0.99))

In [12]:
data = torch.load(os.path.join('/home/yoos-bii/Desktop/workspace/diffusion_digital_pathology/Checkpoint-diffusion/results_cond_512TO256_GTEX', f'model-{30}.pt'), map_location=accelerator.device)

# print(data)

model = accelerator.unwrap_model(diffusion)
model.load_state_dict(data['model'])

opt = opt.load_state_dict(data['opt'])



In [13]:
ema = EMA(diffusion, beta = (0.9, 0.99), update_every = 0.995)
ema = ema.to(accelerator.device)

ema.load_state_dict(data['ema'])

<All keys matched successfully>

In [14]:
model, opt = accelerator.prepare(diffusion, opt)

In [15]:
# model, opt 
print(opt)

None


In [16]:
def exists(x):
    return x is not None

def load(opt, accelerator, milestone, model):
    accelerator = accelerator
    device = accelerator.device

    data = torch.load(os.path.join('/home/yoos-bii/Desktop/workspace/diffusion_digital_pathology/Checkpoint-diffusion/results_cond_512TO256_GTEX', f'model-{milestone}.pt'), map_location=device)

    model = accelerator.unwrap_model(model)
    model.load_state_dict(data['model'])

    step = data['step']
    opt.load_state_dict(data['opt'])

    if accelerator.is_main_process:
        return ema.load_state_dict(data['ema'])


    if 'version' in data:
        print(f"loading from version {data['version']}")

    if exists(accelerator.scaler) and exists(data['scaler']):
        accelerator.scaler.load_state_dict(data['scaler'])

In [43]:
load_model = load(opt, accelerator, '30', model)
load_model

<All keys matched successfully>

In [21]:
fid_score = FIDEvaluation(
    batch_size = 1, 
    dl=dl, 
    sampler=ema.ema_model, 
    # sampler=load_model,
    channels=3,
    accelerator=accelerator,
    # stats_dir='/home/yoos-bii/Desktop/workspace/diffusion_digital_pathology/Checkpoint-diffusion/output_fid',
    device=accelerator.device, 
    num_fid_samples=50000,
    inception_block_idx=2048
)

In [22]:
get_fid_score = fid_score.fid_score()


Stacking Inception features for 50000 samples from the real dataset.


  0%|          | 0/50000 [00:00<?, ?it/s]


AttributeError: 'list' object has no attribute 'to'