# DDPM

## Set global variable

In [6]:
# Path to dataset
DATASET = "dataset/cats/"
# DATASET = "dataset/lfw/"
DATASET_NAME = DATASET.rstrip('/').split('/')[-1]

# Hyperparameter for training
IMG_SIZE = 64
CHANNELS_IMG = 3
BATCH_SIZE = 16

LEARNING_RATE = 1e-4  #0.0001
# LEARNING_RATE = 1e-5  #0.00001
NUM_EPOCHS = 100

print(DATASET_NAME)

cats


## Count and set number of image variable

In [7]:
import os
import glob

def count_image_files(folder_path, extensions=['.jpg', '.jpeg', '.png', '.gif']):
    count = 0
    for extension in extensions:
        search_pattern = os.path.join(folder_path, '*' + extension)
        count += len(glob.glob(search_pattern))
    return count


In [8]:
NUM_IMAGES = count_image_files(DATASET)

print(f"Number of image files in {DATASET}: {NUM_IMAGES}")

Number of image files in dataset/cats/: 15747


## Set agnostic code

In [4]:
import torch

# agnostic code to detect the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Load images to tensor and normalize images

In [5]:
import torchvision.transforms as transforms
from PIL import Image

# Get the list of image filenames
image_filenames = os.listdir(DATASET)

# Define the image transformations
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), # convert PIL image to tensor and scales data into [0,1] 
    # transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1] 
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # Scale between [-1, 1] by (input[channel] - mean[channel]) / std[channel]
])

# Create a list to store the transformed images
transformed_images = []

# Iterate over the first num_images filenames and transform the corresponding images
for i, filename in enumerate(image_filenames[:NUM_IMAGES]):
    # Load the image
    img_path = os.path.join(DATASET, filename)
    image = Image.open(img_path)

    # Apply the transformations
    transformed_image = transform(image)

    # Append the transformed image to the list
    transformed_images.append(transformed_image)

# Convert the list of transformed images to a PyTorch tensor
transformed_images = torch.stack(transformed_images)

print(f'Loaded data: {transformed_images.shape}')

Loaded data: torch.Size([12876, 3, 64, 64])


## Separate batch size 

In [6]:
from torch.utils.data import DataLoader

data_loader = DataLoader(transformed_images, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# print first batch.shape
data_iter = iter(data_loader)
print(next(data_iter).shape)

torch.Size([16, 3, 64, 64])


## Train model and save weight and log 

## training process
parameter base on DDPM and unet papers https://arxiv.org/pdf/1505.04597v1.pdf

In [6]:
import time
from modules.ddpm import Diffusion, UNet

import torch.optim as optim
import torch.nn as nn
import torchvision

from torch.utils.tensorboard import SummaryWriter

model = UNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
mse = nn.MSELoss()
diffusion = Diffusion(img_size=IMG_SIZE, device=device)
writer = SummaryWriter(os.path.join(f"logs/{DATASET_NAME}","DDPM"))

l = len(data_loader)

time_use = 0

for epoch in range(NUM_EPOCHS):

    lossMean = 0

    # use time for time measurement
    start_time = time.time()

    model.train()
    
    for batch_idx, images in enumerate(data_loader):
        images = images.to(device)
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(images, t)
        predicted_noise = model(x_t, t)
        loss = mse(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # sum lose
        lossMean += loss

        # add log at first 
        if epoch == 0 and batch_idx == 0:
            writer.add_scalar("Loss/MSE", lossMean, global_step=epoch)
            writer.add_scalar("traing time", time_use, epoch)

    end_time = time.time()
    epoch_time = end_time - start_time

    time_use += epoch_time 

    epoch_index = epoch+1

    # calculate mean value
    lossMean = lossMean / len(data_loader.dataset)

    print(
        f"Epoch [{epoch_index}/{NUM_EPOCHS}] Batch {batch_idx+1}/{len(data_loader)} Using Time: {epoch_time:.4f}\
            Loss: {lossMean:.4f}"
    )

    # tensorboard
    writer.add_scalar("Loss/MSE", lossMean, global_step=epoch+1)
    writer.add_scalar("traing time", time_use, epoch_index)

    weight_path = os.path.join("weights", f'{DATASET_NAME}', "DDPM")
    os.makedirs(weight_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join("weights", f'{DATASET_NAME}', "DDPM", f"{epoch_index}.pt"))

    model.eval()
    with torch.no_grad():
        x = diffusion.sample(model, n=32).type(dtype=torch.float32)
        img_grid = torchvision.utils.make_grid(x[:32], normalize=True)
        writer.add_image("All/Gen", img_grid, global_step=epoch_index)


  return F.conv2d(input, weight, bias, self.stride,


Epoch [1/100] Batch 804/804 Using Time: 206.7065            Loss: 0.0072


999it [02:43,  6.10it/s]


Epoch [2/100] Batch 804/804 Using Time: 212.6587            Loss: 0.0025


999it [02:44,  6.07it/s]


Epoch [3/100] Batch 804/804 Using Time: 212.7700            Loss: 0.0021


999it [02:45,  6.04it/s]


Epoch [4/100] Batch 804/804 Using Time: 213.6989            Loss: 0.0018


999it [02:45,  6.02it/s]


Epoch [5/100] Batch 804/804 Using Time: 213.9060            Loss: 0.0017


999it [02:46,  6.00it/s]


Epoch [6/100] Batch 804/804 Using Time: 213.9101            Loss: 0.0016


999it [02:45,  6.02it/s]


Epoch [7/100] Batch 804/804 Using Time: 213.5752            Loss: 0.0014


999it [02:46,  6.00it/s]


Epoch [8/100] Batch 804/804 Using Time: 215.2891            Loss: 0.0015


999it [02:47,  5.96it/s]


Epoch [9/100] Batch 804/804 Using Time: 215.9116            Loss: 0.0013


999it [02:48,  5.94it/s]


Epoch [10/100] Batch 804/804 Using Time: 215.8017            Loss: 0.0013


999it [02:47,  5.96it/s]


Epoch [11/100] Batch 804/804 Using Time: 215.4030            Loss: 0.0013


999it [02:47,  5.96it/s]


Epoch [12/100] Batch 804/804 Using Time: 214.9980            Loss: 0.0012


999it [02:47,  5.95it/s]


Epoch [13/100] Batch 804/804 Using Time: 215.4285            Loss: 0.0012


999it [02:47,  5.97it/s]


Epoch [14/100] Batch 804/804 Using Time: 215.0088            Loss: 0.0012


999it [02:47,  5.97it/s]


Epoch [15/100] Batch 804/804 Using Time: 213.9613            Loss: 0.0012


999it [02:47,  5.98it/s]


Epoch [16/100] Batch 804/804 Using Time: 215.4945            Loss: 0.0012


999it [02:48,  5.94it/s]


Epoch [17/100] Batch 804/804 Using Time: 215.5220            Loss: 0.0012


999it [02:48,  5.94it/s]


Epoch [18/100] Batch 804/804 Using Time: 215.2501            Loss: 0.0012


999it [02:47,  5.95it/s]


Epoch [19/100] Batch 804/804 Using Time: 215.0584            Loss: 0.0011


999it [02:47,  5.96it/s]


Epoch [20/100] Batch 804/804 Using Time: 215.4775            Loss: 0.0011


999it [02:47,  5.96it/s]


Epoch [21/100] Batch 804/804 Using Time: 215.1458            Loss: 0.0011


999it [02:47,  5.96it/s]


Epoch [22/100] Batch 804/804 Using Time: 215.7161            Loss: 0.0011


999it [02:48,  5.93it/s]


Epoch [23/100] Batch 804/804 Using Time: 215.8502            Loss: 0.0011


999it [02:47,  5.96it/s]


Epoch [24/100] Batch 804/804 Using Time: 215.2973            Loss: 0.0011


999it [02:47,  5.97it/s]


Epoch [25/100] Batch 804/804 Using Time: 212.0260            Loss: 0.0011


999it [02:44,  6.06it/s]


Epoch [26/100] Batch 804/804 Using Time: 211.7509            Loss: 0.0010


999it [02:43,  6.09it/s]


Epoch [27/100] Batch 804/804 Using Time: 211.9422            Loss: 0.0011


999it [02:44,  6.07it/s]


Epoch [28/100] Batch 804/804 Using Time: 211.7682            Loss: 0.0010


999it [02:44,  6.06it/s]


Epoch [29/100] Batch 804/804 Using Time: 212.0861            Loss: 0.0011


999it [02:44,  6.07it/s]


Epoch [30/100] Batch 804/804 Using Time: 211.4202            Loss: 0.0010


999it [02:44,  6.07it/s]


Epoch [31/100] Batch 804/804 Using Time: 211.8441            Loss: 0.0010


999it [02:44,  6.06it/s]


Epoch [32/100] Batch 804/804 Using Time: 212.1057            Loss: 0.0010


999it [02:44,  6.05it/s]


Epoch [33/100] Batch 804/804 Using Time: 212.0261            Loss: 0.0011


999it [02:44,  6.09it/s]


Epoch [34/100] Batch 804/804 Using Time: 211.9107            Loss: 0.0010


999it [02:44,  6.08it/s]


Epoch [35/100] Batch 804/804 Using Time: 210.9532            Loss: 0.0010


999it [02:43,  6.12it/s]


Epoch [36/100] Batch 804/804 Using Time: 210.6350            Loss: 0.0010


999it [02:44,  6.09it/s]


Epoch [37/100] Batch 804/804 Using Time: 211.5892            Loss: 0.0010


999it [02:44,  6.07it/s]


Epoch [38/100] Batch 804/804 Using Time: 211.1664            Loss: 0.0010


999it [02:43,  6.10it/s]


Epoch [39/100] Batch 804/804 Using Time: 210.8496            Loss: 0.0010


999it [02:43,  6.10it/s]


Epoch [40/100] Batch 804/804 Using Time: 211.2310            Loss: 0.0010


999it [02:43,  6.10it/s]


Epoch [41/100] Batch 804/804 Using Time: 210.9908            Loss: 0.0010


999it [02:44,  6.09it/s]


Epoch [42/100] Batch 804/804 Using Time: 210.8354            Loss: 0.0010


999it [02:43,  6.10it/s]


Epoch [43/100] Batch 804/804 Using Time: 210.7328            Loss: 0.0010


999it [02:43,  6.12it/s]


Epoch [44/100] Batch 804/804 Using Time: 210.1425            Loss: 0.0010


999it [02:43,  6.10it/s]


Epoch [45/100] Batch 804/804 Using Time: 210.3559            Loss: 0.0010


999it [02:43,  6.11it/s]


Epoch [46/100] Batch 804/804 Using Time: 210.2317            Loss: 0.0010


999it [02:43,  6.12it/s]


Epoch [47/100] Batch 804/804 Using Time: 210.3400            Loss: 0.0010


999it [02:43,  6.13it/s]


Epoch [48/100] Batch 804/804 Using Time: 209.8910            Loss: 0.0010


999it [02:43,  6.12it/s]


Epoch [49/100] Batch 804/804 Using Time: 210.3391            Loss: 0.0010


999it [02:43,  6.11it/s]


Epoch [50/100] Batch 804/804 Using Time: 210.4483            Loss: 0.0010


999it [02:42,  6.13it/s]


Epoch [51/100] Batch 804/804 Using Time: 210.2085            Loss: 0.0010


999it [02:42,  6.14it/s]


Epoch [52/100] Batch 804/804 Using Time: 209.7093            Loss: 0.0010


999it [02:42,  6.16it/s]


Epoch [53/100] Batch 804/804 Using Time: 209.2183            Loss: 0.0010


999it [02:42,  6.14it/s]


Epoch [54/100] Batch 804/804 Using Time: 210.6819            Loss: 0.0010


999it [02:43,  6.12it/s]


Epoch [55/100] Batch 804/804 Using Time: 210.0818            Loss: 0.0010


999it [02:43,  6.12it/s]


Epoch [56/100] Batch 804/804 Using Time: 210.1974            Loss: 0.0010


999it [02:42,  6.13it/s]


Epoch [57/100] Batch 804/804 Using Time: 210.1996            Loss: 0.0010


999it [02:43,  6.13it/s]


Epoch [58/100] Batch 804/804 Using Time: 210.0766            Loss: 0.0010


999it [02:42,  6.13it/s]


Epoch [59/100] Batch 804/804 Using Time: 210.2894            Loss: 0.0009


999it [02:42,  6.13it/s]


Epoch [60/100] Batch 804/804 Using Time: 210.3866            Loss: 0.0009


999it [02:42,  6.13it/s]


Epoch [61/100] Batch 804/804 Using Time: 209.5156            Loss: 0.0010


999it [02:42,  6.16it/s]


Epoch [62/100] Batch 804/804 Using Time: 209.4802            Loss: 0.0010


999it [02:41,  6.17it/s]


Epoch [63/100] Batch 804/804 Using Time: 209.1459            Loss: 0.0010


999it [02:42,  6.16it/s]


Epoch [64/100] Batch 804/804 Using Time: 209.2562            Loss: 0.0010


999it [02:41,  6.17it/s]


Epoch [65/100] Batch 804/804 Using Time: 209.2849            Loss: 0.0010


999it [02:42,  6.16it/s]


Epoch [66/100] Batch 804/804 Using Time: 208.8800            Loss: 0.0010


999it [02:42,  6.16it/s]


Epoch [67/100] Batch 804/804 Using Time: 209.6304            Loss: 0.0009


999it [02:42,  6.15it/s]


Epoch [68/100] Batch 804/804 Using Time: 209.5612            Loss: 0.0009


999it [02:41,  6.18it/s]


Epoch [69/100] Batch 804/804 Using Time: 208.6042            Loss: 0.0009


999it [02:41,  6.17it/s]


Epoch [70/100] Batch 804/804 Using Time: 209.0393            Loss: 0.0010


999it [02:42,  6.16it/s]


Epoch [71/100] Batch 804/804 Using Time: 209.2235            Loss: 0.0009


999it [02:42,  6.17it/s]


Epoch [72/100] Batch 804/804 Using Time: 208.7986            Loss: 0.0009


999it [02:41,  6.17it/s]


Epoch [73/100] Batch 804/804 Using Time: 208.7386            Loss: 0.0010


999it [02:41,  6.17it/s]


Epoch [74/100] Batch 804/804 Using Time: 209.0527            Loss: 0.0009


999it [02:41,  6.18it/s]


Epoch [75/100] Batch 804/804 Using Time: 208.8446            Loss: 0.0009


999it [02:41,  6.17it/s]


Epoch [76/100] Batch 804/804 Using Time: 209.3455            Loss: 0.0009


999it [02:42,  6.16it/s]


Epoch [77/100] Batch 804/804 Using Time: 209.7231            Loss: 0.0010


999it [02:49,  5.88it/s]


Epoch [78/100] Batch 804/804 Using Time: 215.8785            Loss: 0.0009


999it [02:47,  5.97it/s]


Epoch [79/100] Batch 804/804 Using Time: 210.9632            Loss: 0.0009


999it [02:42,  6.14it/s]


Epoch [80/100] Batch 804/804 Using Time: 210.9089            Loss: 0.0009


999it [02:44,  6.06it/s]


Epoch [81/100] Batch 804/804 Using Time: 213.3553            Loss: 0.0009


999it [02:44,  6.09it/s]


Epoch [82/100] Batch 804/804 Using Time: 211.9788            Loss: 0.0009


999it [02:43,  6.09it/s]


Epoch [83/100] Batch 804/804 Using Time: 210.7370            Loss: 0.0010


999it [02:44,  6.07it/s]


Epoch [84/100] Batch 804/804 Using Time: 211.4698            Loss: 0.0009


999it [02:43,  6.11it/s]


Epoch [85/100] Batch 804/804 Using Time: 210.7267            Loss: 0.0009


999it [02:43,  6.10it/s]


Epoch [86/100] Batch 804/804 Using Time: 212.2509            Loss: 0.0009


999it [02:45,  6.04it/s]


Epoch [87/100] Batch 804/804 Using Time: 212.4837            Loss: 0.0009


999it [02:44,  6.09it/s]


Epoch [88/100] Batch 804/804 Using Time: 213.4489            Loss: 0.0009


999it [02:44,  6.07it/s]


Epoch [89/100] Batch 804/804 Using Time: 214.5439            Loss: 0.0009


999it [02:44,  6.08it/s]


Epoch [90/100] Batch 804/804 Using Time: 212.4186            Loss: 0.0009


999it [02:43,  6.10it/s]


Epoch [91/100] Batch 804/804 Using Time: 210.3076            Loss: 0.0009


999it [02:44,  6.07it/s]


Epoch [92/100] Batch 804/804 Using Time: 211.7946            Loss: 0.0009


999it [02:44,  6.07it/s]


Epoch [93/100] Batch 804/804 Using Time: 212.7197            Loss: 0.0009


999it [02:44,  6.06it/s]


Epoch [94/100] Batch 804/804 Using Time: 213.1965            Loss: 0.0009


999it [02:49,  5.88it/s]


Epoch [95/100] Batch 804/804 Using Time: 213.8210            Loss: 0.0009


999it [02:46,  5.99it/s]


Epoch [96/100] Batch 804/804 Using Time: 214.5201            Loss: 0.0009


999it [02:46,  5.98it/s]


Epoch [97/100] Batch 804/804 Using Time: 214.5922            Loss: 0.0009


999it [02:46,  6.00it/s]


Epoch [98/100] Batch 804/804 Using Time: 212.8345            Loss: 0.0009


999it [02:45,  6.03it/s]


Epoch [99/100] Batch 804/804 Using Time: 213.2356            Loss: 0.0009


999it [02:44,  6.07it/s]


Epoch [100/100] Batch 804/804 Using Time: 211.6610            Loss: 0.0009


999it [02:44,  6.08it/s]


# Resume Training

In [7]:
def load_weight( model, weight_path="", index=int):
    # Load the saved weights
    model.load_state_dict(torch.load(f'{weight_path}/{index}.pt'))

## get logs latest training time to count on this number

In [8]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import matplotlib.pyplot as plt

# Function to get data from event accumulator
def get_data(log_dir):
    event_acc = EventAccumulator(log_dir)
    event_acc.Reload()

    tags = event_acc.Tags()['scalars']
    selected_tag = 'traing time'
    tag_values = event_acc.Scalars(selected_tag)
    
    steps = [x.step for x in tag_values]
    values = [x.value for x in tag_values]
    return steps, values

steps, values = get_data(f'logs/{DATASET_NAME}/DDPM/')

print(len(steps))
print(len(values))

print(steps.index(steps[-1]), "value:", steps[-1] )
print(values.index(values[-1]), "value:", values[-1] )
# print(values.index(values[451]), "value:", values[451] )

144
144
143 value: 143
143 value: 30045.22265625


## resume

In [9]:
import time
from modules.ddpm import Diffusion, UNet

import torch.optim as optim
import torch.nn as nn
import torchvision

from torch.utils.tensorboard import SummaryWriter


# resume number
start = steps[-1]
stop = 500 

model = UNet().to(device)

# load weight
load_weight(model=model, weight_path=f'weights/{DATASET_NAME}/DDPM/', index=start)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
mse = nn.MSELoss()
diffusion = Diffusion(img_size=IMG_SIZE, device=device)
writer = SummaryWriter(os.path.join(f"logs/{DATASET_NAME}","DDPM"))

# l = len(data_loader)

# get trianing time to resume counting number
time_use = values[-1]

for epoch in range(start,stop):

    lossMean = 0

    # use time for time measurement
    start_time = time.time()

    model.train()
    
    for batch_idx, images in enumerate(data_loader):
        images = images.to(device)
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(images, t)
        predicted_noise = model(x_t, t)
        loss = mse(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # sum lose
        lossMean += loss

    end_time = time.time()
    epoch_time = end_time - start_time

    time_use += epoch_time 

    epoch_index = epoch+1

    # calculate mean value
    lossMean = lossMean / len(data_loader.dataset)

    print(
        f"Epoch [{epoch_index}/{stop}] Batch {batch_idx+1}/{len(data_loader)} Using Time: {epoch_time:.4f}\
            Loss: {lossMean:.4f}"
    )

    # tensorboard
    writer.add_scalar("Loss/MSE", lossMean, global_step=epoch+1)
    writer.add_scalar("traing time", time_use, epoch_index)

    weight_path = os.path.join("weights", f'{DATASET_NAME}', "DDPM")
    os.makedirs(weight_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join("weights", f'{DATASET_NAME}', "DDPM", f"{epoch_index}.pt"))

    model.eval()
    with torch.no_grad():
        x = diffusion.sample(model, n=32).type(dtype=torch.float32)
        img_grid = torchvision.utils.make_grid(x[:32], normalize=True)
        writer.add_image("All/Gen", img_grid, global_step=epoch_index)


  return F.conv2d(input, weight, bias, self.stride,


Epoch [144/500] Batch 804/804 Using Time: 212.2353            Loss: 0.0009


999it [02:37,  6.35it/s]


Epoch [145/500] Batch 804/804 Using Time: 203.1017            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [146/500] Batch 804/804 Using Time: 203.7792            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [147/500] Batch 804/804 Using Time: 203.9590            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [148/500] Batch 804/804 Using Time: 204.5378            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [149/500] Batch 804/804 Using Time: 204.3475            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [150/500] Batch 804/804 Using Time: 204.3823            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [151/500] Batch 804/804 Using Time: 204.6872            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [152/500] Batch 804/804 Using Time: 204.5245            Loss: 0.0009


999it [02:37,  6.32it/s]


Epoch [153/500] Batch 804/804 Using Time: 204.3995            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [154/500] Batch 804/804 Using Time: 204.4981            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [155/500] Batch 804/804 Using Time: 204.4364            Loss: 0.0009


999it [02:37,  6.32it/s]


Epoch [156/500] Batch 804/804 Using Time: 204.4355            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [157/500] Batch 804/804 Using Time: 204.6036            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [158/500] Batch 804/804 Using Time: 204.4263            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [159/500] Batch 804/804 Using Time: 204.7881            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [160/500] Batch 804/804 Using Time: 204.7667            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [161/500] Batch 804/804 Using Time: 204.1734            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [162/500] Batch 804/804 Using Time: 203.9578            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [163/500] Batch 804/804 Using Time: 204.1449            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [164/500] Batch 804/804 Using Time: 205.0303            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [165/500] Batch 804/804 Using Time: 204.6589            Loss: 0.0009


999it [02:37,  6.32it/s]


Epoch [166/500] Batch 804/804 Using Time: 204.7612            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [167/500] Batch 804/804 Using Time: 204.8768            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [168/500] Batch 804/804 Using Time: 204.8127            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [169/500] Batch 804/804 Using Time: 204.4880            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [170/500] Batch 804/804 Using Time: 204.2306            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [171/500] Batch 804/804 Using Time: 204.4476            Loss: 0.0009


999it [02:37,  6.32it/s]


Epoch [172/500] Batch 804/804 Using Time: 203.9617            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [173/500] Batch 804/804 Using Time: 203.7339            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [174/500] Batch 804/804 Using Time: 203.6836            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [175/500] Batch 804/804 Using Time: 203.8369            Loss: 0.0009


999it [02:37,  6.32it/s]


Epoch [176/500] Batch 804/804 Using Time: 203.7641            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [177/500] Batch 804/804 Using Time: 204.1363            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [178/500] Batch 804/804 Using Time: 204.4376            Loss: 0.0009


999it [02:37,  6.32it/s]


Epoch [179/500] Batch 804/804 Using Time: 204.4121            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [180/500] Batch 804/804 Using Time: 204.3772            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [181/500] Batch 804/804 Using Time: 203.2974            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [182/500] Batch 804/804 Using Time: 203.3167            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [183/500] Batch 804/804 Using Time: 203.2552            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [184/500] Batch 804/804 Using Time: 203.2279            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [185/500] Batch 804/804 Using Time: 203.2494            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [186/500] Batch 804/804 Using Time: 203.2305            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [187/500] Batch 804/804 Using Time: 203.2903            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [188/500] Batch 804/804 Using Time: 203.2487            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [189/500] Batch 804/804 Using Time: 203.1893            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [190/500] Batch 804/804 Using Time: 203.3426            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [191/500] Batch 804/804 Using Time: 203.2556            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [192/500] Batch 804/804 Using Time: 203.1936            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [193/500] Batch 804/804 Using Time: 203.2427            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [194/500] Batch 804/804 Using Time: 203.2138            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [195/500] Batch 804/804 Using Time: 203.4817            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [196/500] Batch 804/804 Using Time: 203.4767            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [197/500] Batch 804/804 Using Time: 203.5374            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [198/500] Batch 804/804 Using Time: 203.3163            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [199/500] Batch 804/804 Using Time: 203.5575            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [200/500] Batch 804/804 Using Time: 203.3255            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [201/500] Batch 804/804 Using Time: 203.2448            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [202/500] Batch 804/804 Using Time: 203.2966            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [203/500] Batch 804/804 Using Time: 203.2554            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [204/500] Batch 804/804 Using Time: 203.2864            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [205/500] Batch 804/804 Using Time: 203.2863            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [206/500] Batch 804/804 Using Time: 203.2350            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [207/500] Batch 804/804 Using Time: 203.3927            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [208/500] Batch 804/804 Using Time: 203.3770            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [209/500] Batch 804/804 Using Time: 203.1953            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [210/500] Batch 804/804 Using Time: 203.6480            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [211/500] Batch 804/804 Using Time: 203.2843            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [212/500] Batch 804/804 Using Time: 203.2367            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [213/500] Batch 804/804 Using Time: 203.1879            Loss: 0.0009


999it [02:37,  6.35it/s]


Epoch [214/500] Batch 804/804 Using Time: 203.4602            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [215/500] Batch 804/804 Using Time: 203.1798            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [216/500] Batch 804/804 Using Time: 203.3089            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [217/500] Batch 804/804 Using Time: 203.2667            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [218/500] Batch 804/804 Using Time: 203.5990            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [219/500] Batch 804/804 Using Time: 203.6944            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [220/500] Batch 804/804 Using Time: 203.2448            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [221/500] Batch 804/804 Using Time: 203.2792            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [222/500] Batch 804/804 Using Time: 203.3132            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [223/500] Batch 804/804 Using Time: 203.2287            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [224/500] Batch 804/804 Using Time: 203.5933            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [225/500] Batch 804/804 Using Time: 203.3576            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [226/500] Batch 804/804 Using Time: 203.3356            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [227/500] Batch 804/804 Using Time: 203.4207            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [228/500] Batch 804/804 Using Time: 203.3922            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [229/500] Batch 804/804 Using Time: 203.6023            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [230/500] Batch 804/804 Using Time: 203.3428            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [231/500] Batch 804/804 Using Time: 203.2441            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [232/500] Batch 804/804 Using Time: 203.1644            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [233/500] Batch 804/804 Using Time: 203.2770            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [234/500] Batch 804/804 Using Time: 203.1857            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [235/500] Batch 804/804 Using Time: 203.5286            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [236/500] Batch 804/804 Using Time: 203.6180            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [237/500] Batch 804/804 Using Time: 203.3690            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [238/500] Batch 804/804 Using Time: 203.7758            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [239/500] Batch 804/804 Using Time: 203.8167            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [240/500] Batch 804/804 Using Time: 203.4440            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [241/500] Batch 804/804 Using Time: 204.5519            Loss: 0.0009


999it [02:38,  6.31it/s]


Epoch [242/500] Batch 804/804 Using Time: 205.1283            Loss: 0.0009


999it [02:38,  6.31it/s]


Epoch [243/500] Batch 804/804 Using Time: 204.6460            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [244/500] Batch 804/804 Using Time: 204.0579            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [245/500] Batch 804/804 Using Time: 203.6324            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [246/500] Batch 804/804 Using Time: 203.5444            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [247/500] Batch 804/804 Using Time: 203.3208            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [248/500] Batch 804/804 Using Time: 203.3747            Loss: 0.0009


999it [02:37,  6.34it/s]


Epoch [249/500] Batch 804/804 Using Time: 203.4300            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [250/500] Batch 804/804 Using Time: 203.3059            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [251/500] Batch 804/804 Using Time: 203.3738            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [252/500] Batch 804/804 Using Time: 203.3996            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [253/500] Batch 804/804 Using Time: 203.3568            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [254/500] Batch 804/804 Using Time: 203.2618            Loss: 0.0009


999it [02:37,  6.32it/s]


Epoch [255/500] Batch 804/804 Using Time: 203.7006            Loss: 0.0009


999it [02:37,  6.32it/s]


Epoch [256/500] Batch 804/804 Using Time: 203.5865            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [257/500] Batch 804/804 Using Time: 203.4951            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [258/500] Batch 804/804 Using Time: 203.6908            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [259/500] Batch 804/804 Using Time: 203.6516            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [260/500] Batch 804/804 Using Time: 203.5695            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [261/500] Batch 804/804 Using Time: 203.5786            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [262/500] Batch 804/804 Using Time: 203.7258            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [263/500] Batch 804/804 Using Time: 203.5296            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [264/500] Batch 804/804 Using Time: 203.7521            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [265/500] Batch 804/804 Using Time: 203.7050            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [266/500] Batch 804/804 Using Time: 203.7136            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [267/500] Batch 804/804 Using Time: 204.2825            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [268/500] Batch 804/804 Using Time: 203.3341            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [269/500] Batch 804/804 Using Time: 205.1730            Loss: 0.0009


999it [02:39,  6.28it/s]


Epoch [270/500] Batch 804/804 Using Time: 206.3722            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [271/500] Batch 804/804 Using Time: 205.0791            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [272/500] Batch 804/804 Using Time: 205.5822            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [273/500] Batch 804/804 Using Time: 204.4395            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [274/500] Batch 804/804 Using Time: 204.9468            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [275/500] Batch 804/804 Using Time: 204.0055            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [276/500] Batch 804/804 Using Time: 204.0868            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [277/500] Batch 804/804 Using Time: 204.6291            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [278/500] Batch 804/804 Using Time: 205.0566            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [279/500] Batch 804/804 Using Time: 204.8413            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [280/500] Batch 804/804 Using Time: 204.7982            Loss: 0.0009


999it [02:38,  6.30it/s]


Epoch [281/500] Batch 804/804 Using Time: 204.0686            Loss: 0.0009


999it [02:38,  6.31it/s]


Epoch [282/500] Batch 804/804 Using Time: 204.6291            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [283/500] Batch 804/804 Using Time: 204.3502            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [284/500] Batch 804/804 Using Time: 204.8042            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [285/500] Batch 804/804 Using Time: 204.0607            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [286/500] Batch 804/804 Using Time: 203.4385            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [287/500] Batch 804/804 Using Time: 203.4934            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [288/500] Batch 804/804 Using Time: 203.7656            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [289/500] Batch 804/804 Using Time: 203.5937            Loss: 0.0009


999it [02:38,  6.32it/s]


Epoch [290/500] Batch 804/804 Using Time: 203.5139            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [291/500] Batch 804/804 Using Time: 203.5666            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [292/500] Batch 804/804 Using Time: 203.4764            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [293/500] Batch 804/804 Using Time: 203.4887            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [294/500] Batch 804/804 Using Time: 203.3828            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [295/500] Batch 804/804 Using Time: 203.8996            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [296/500] Batch 804/804 Using Time: 203.4398            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [297/500] Batch 804/804 Using Time: 203.8466            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [298/500] Batch 804/804 Using Time: 203.5526            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [299/500] Batch 804/804 Using Time: 203.5596            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [300/500] Batch 804/804 Using Time: 203.4413            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [301/500] Batch 804/804 Using Time: 203.4006            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [302/500] Batch 804/804 Using Time: 203.7641            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [303/500] Batch 804/804 Using Time: 203.4987            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [304/500] Batch 804/804 Using Time: 203.5044            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [305/500] Batch 804/804 Using Time: 203.6350            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [306/500] Batch 804/804 Using Time: 203.4259            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [307/500] Batch 804/804 Using Time: 203.4538            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [308/500] Batch 804/804 Using Time: 203.4731            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [309/500] Batch 804/804 Using Time: 203.5737            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [310/500] Batch 804/804 Using Time: 203.4690            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [311/500] Batch 804/804 Using Time: 203.7951            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [312/500] Batch 804/804 Using Time: 203.9218            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [313/500] Batch 804/804 Using Time: 203.4972            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [314/500] Batch 804/804 Using Time: 203.4572            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [315/500] Batch 804/804 Using Time: 203.4431            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [316/500] Batch 804/804 Using Time: 204.3680            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [317/500] Batch 804/804 Using Time: 205.1719            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [318/500] Batch 804/804 Using Time: 203.8874            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [319/500] Batch 804/804 Using Time: 203.6224            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [320/500] Batch 804/804 Using Time: 203.8167            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [321/500] Batch 804/804 Using Time: 205.1520            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [322/500] Batch 804/804 Using Time: 205.1875            Loss: 0.0009


999it [02:38,  6.31it/s]


Epoch [323/500] Batch 804/804 Using Time: 205.2711            Loss: 0.0008


999it [02:39,  6.28it/s]


Epoch [324/500] Batch 804/804 Using Time: 207.2343            Loss: 0.0008


999it [02:40,  6.23it/s]


Epoch [325/500] Batch 804/804 Using Time: 206.6530            Loss: 0.0008


999it [02:39,  6.28it/s]


Epoch [326/500] Batch 804/804 Using Time: 205.2558            Loss: 0.0008


999it [02:39,  6.27it/s]


Epoch [327/500] Batch 804/804 Using Time: 204.9228            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [328/500] Batch 804/804 Using Time: 204.7877            Loss: 0.0009


999it [02:38,  6.30it/s]


Epoch [329/500] Batch 804/804 Using Time: 204.5557            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [330/500] Batch 804/804 Using Time: 204.1064            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [331/500] Batch 804/804 Using Time: 205.7933            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [332/500] Batch 804/804 Using Time: 205.9931            Loss: 0.0008


999it [02:39,  6.26it/s]


Epoch [333/500] Batch 804/804 Using Time: 206.4847            Loss: 0.0008


999it [02:39,  6.26it/s]


Epoch [334/500] Batch 804/804 Using Time: 207.0015            Loss: 0.0008


999it [02:39,  6.24it/s]


Epoch [335/500] Batch 804/804 Using Time: 206.5390            Loss: 0.0008


999it [02:39,  6.26it/s]


Epoch [336/500] Batch 804/804 Using Time: 206.3893            Loss: 0.0008


999it [02:39,  6.27it/s]


Epoch [337/500] Batch 804/804 Using Time: 206.1156            Loss: 0.0009


999it [02:38,  6.30it/s]


Epoch [338/500] Batch 804/804 Using Time: 206.7842            Loss: 0.0008


999it [02:39,  6.27it/s]


Epoch [339/500] Batch 804/804 Using Time: 206.3738            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [340/500] Batch 804/804 Using Time: 205.6448            Loss: 0.0009


999it [02:38,  6.29it/s]


Epoch [341/500] Batch 804/804 Using Time: 205.5866            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [342/500] Batch 804/804 Using Time: 204.0611            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [343/500] Batch 804/804 Using Time: 203.4129            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [344/500] Batch 804/804 Using Time: 203.4581            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [345/500] Batch 804/804 Using Time: 204.6084            Loss: 0.0009


999it [02:38,  6.31it/s]


Epoch [346/500] Batch 804/804 Using Time: 205.8466            Loss: 0.0008


999it [02:39,  6.26it/s]


Epoch [347/500] Batch 804/804 Using Time: 207.0615            Loss: 0.0008


999it [02:39,  6.27it/s]


Epoch [348/500] Batch 804/804 Using Time: 207.2476            Loss: 0.0008


999it [02:39,  6.25it/s]


Epoch [349/500] Batch 804/804 Using Time: 207.4721            Loss: 0.0008


999it [02:39,  6.25it/s]


Epoch [350/500] Batch 804/804 Using Time: 206.9656            Loss: 0.0008


999it [02:42,  6.16it/s]


Epoch [351/500] Batch 804/804 Using Time: 208.6052            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [352/500] Batch 804/804 Using Time: 207.1196            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [353/500] Batch 804/804 Using Time: 205.6908            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [354/500] Batch 804/804 Using Time: 205.8951            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [355/500] Batch 804/804 Using Time: 205.5556            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [356/500] Batch 804/804 Using Time: 205.7292            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [357/500] Batch 804/804 Using Time: 207.0294            Loss: 0.0008


999it [02:39,  6.27it/s]


Epoch [358/500] Batch 804/804 Using Time: 207.6184            Loss: 0.0008


999it [02:41,  6.17it/s]


Epoch [359/500] Batch 804/804 Using Time: 204.6923            Loss: 0.0009


999it [02:38,  6.31it/s]


Epoch [360/500] Batch 804/804 Using Time: 207.6574            Loss: 0.0008


999it [02:39,  6.26it/s]


Epoch [361/500] Batch 804/804 Using Time: 206.5252            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [362/500] Batch 804/804 Using Time: 206.3290            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [363/500] Batch 804/804 Using Time: 206.3382            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [364/500] Batch 804/804 Using Time: 206.3985            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [365/500] Batch 804/804 Using Time: 205.7012            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [366/500] Batch 804/804 Using Time: 205.0575            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [367/500] Batch 804/804 Using Time: 205.6345            Loss: 0.0008


999it [02:40,  6.21it/s]


Epoch [368/500] Batch 804/804 Using Time: 208.5652            Loss: 0.0008


999it [02:39,  6.26it/s]


Epoch [369/500] Batch 804/804 Using Time: 206.6568            Loss: 0.0008


999it [02:39,  6.27it/s]


Epoch [370/500] Batch 804/804 Using Time: 206.8654            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [371/500] Batch 804/804 Using Time: 204.4085            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [372/500] Batch 804/804 Using Time: 204.6717            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [373/500] Batch 804/804 Using Time: 204.6400            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [374/500] Batch 804/804 Using Time: 205.9635            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [375/500] Batch 804/804 Using Time: 205.8337            Loss: 0.0008


999it [02:39,  6.28it/s]


Epoch [376/500] Batch 804/804 Using Time: 206.2275            Loss: 0.0008


999it [02:41,  6.20it/s]


Epoch [377/500] Batch 804/804 Using Time: 204.0037            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [378/500] Batch 804/804 Using Time: 204.1532            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [379/500] Batch 804/804 Using Time: 203.6091            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [380/500] Batch 804/804 Using Time: 203.6633            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [381/500] Batch 804/804 Using Time: 204.0077            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [382/500] Batch 804/804 Using Time: 203.6934            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [383/500] Batch 804/804 Using Time: 203.6663            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [384/500] Batch 804/804 Using Time: 203.5537            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [385/500] Batch 804/804 Using Time: 203.5330            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [386/500] Batch 804/804 Using Time: 203.7057            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [387/500] Batch 804/804 Using Time: 203.5588            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [388/500] Batch 804/804 Using Time: 203.5614            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [389/500] Batch 804/804 Using Time: 203.6673            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [390/500] Batch 804/804 Using Time: 203.7595            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [391/500] Batch 804/804 Using Time: 203.5127            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [392/500] Batch 804/804 Using Time: 203.5719            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [393/500] Batch 804/804 Using Time: 203.5748            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [394/500] Batch 804/804 Using Time: 203.7058            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [395/500] Batch 804/804 Using Time: 203.9174            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [396/500] Batch 804/804 Using Time: 203.8900            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [397/500] Batch 804/804 Using Time: 203.5964            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [398/500] Batch 804/804 Using Time: 203.7188            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [399/500] Batch 804/804 Using Time: 203.5611            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [400/500] Batch 804/804 Using Time: 203.6112            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [401/500] Batch 804/804 Using Time: 203.5956            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [402/500] Batch 804/804 Using Time: 203.6831            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [403/500] Batch 804/804 Using Time: 203.5082            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [404/500] Batch 804/804 Using Time: 203.6314            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [405/500] Batch 804/804 Using Time: 203.9256            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [406/500] Batch 804/804 Using Time: 204.0376            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [407/500] Batch 804/804 Using Time: 203.6678            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [408/500] Batch 804/804 Using Time: 203.4843            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [409/500] Batch 804/804 Using Time: 203.3671            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [410/500] Batch 804/804 Using Time: 203.3579            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [411/500] Batch 804/804 Using Time: 203.4996            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [412/500] Batch 804/804 Using Time: 203.8136            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [413/500] Batch 804/804 Using Time: 204.0610            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [414/500] Batch 804/804 Using Time: 203.6768            Loss: 0.0009


999it [02:37,  6.33it/s]


Epoch [415/500] Batch 804/804 Using Time: 203.6007            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [416/500] Batch 804/804 Using Time: 203.3667            Loss: 0.0008


999it [02:37,  6.34it/s]


Epoch [417/500] Batch 804/804 Using Time: 203.5245            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [418/500] Batch 804/804 Using Time: 203.5095            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [419/500] Batch 804/804 Using Time: 203.5577            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [420/500] Batch 804/804 Using Time: 203.4437            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [421/500] Batch 804/804 Using Time: 203.6192            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [422/500] Batch 804/804 Using Time: 203.8068            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [423/500] Batch 804/804 Using Time: 204.0061            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [424/500] Batch 804/804 Using Time: 203.7313            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [425/500] Batch 804/804 Using Time: 203.6650            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [426/500] Batch 804/804 Using Time: 203.9895            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [427/500] Batch 804/804 Using Time: 203.8369            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [428/500] Batch 804/804 Using Time: 204.0812            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [429/500] Batch 804/804 Using Time: 203.8458            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [430/500] Batch 804/804 Using Time: 203.7205            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [431/500] Batch 804/804 Using Time: 203.6527            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [432/500] Batch 804/804 Using Time: 203.6127            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [433/500] Batch 804/804 Using Time: 203.6074            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [434/500] Batch 804/804 Using Time: 203.9053            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [435/500] Batch 804/804 Using Time: 203.9097            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [436/500] Batch 804/804 Using Time: 204.8577            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [437/500] Batch 804/804 Using Time: 204.9179            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [438/500] Batch 804/804 Using Time: 204.8693            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [439/500] Batch 804/804 Using Time: 204.5573            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [440/500] Batch 804/804 Using Time: 204.6052            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [441/500] Batch 804/804 Using Time: 204.2508            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [442/500] Batch 804/804 Using Time: 205.0394            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [443/500] Batch 804/804 Using Time: 204.7870            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [444/500] Batch 804/804 Using Time: 205.2068            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [445/500] Batch 804/804 Using Time: 208.2720            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [446/500] Batch 804/804 Using Time: 205.2439            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [447/500] Batch 804/804 Using Time: 205.4398            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [448/500] Batch 804/804 Using Time: 205.5824            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [449/500] Batch 804/804 Using Time: 205.2027            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [450/500] Batch 804/804 Using Time: 205.1266            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [451/500] Batch 804/804 Using Time: 205.0704            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [452/500] Batch 804/804 Using Time: 205.5392            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [453/500] Batch 804/804 Using Time: 205.3329            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [454/500] Batch 804/804 Using Time: 205.4942            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [455/500] Batch 804/804 Using Time: 206.0263            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [456/500] Batch 804/804 Using Time: 205.6597            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [457/500] Batch 804/804 Using Time: 205.9471            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [458/500] Batch 804/804 Using Time: 205.5895            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [459/500] Batch 804/804 Using Time: 206.0304            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [460/500] Batch 804/804 Using Time: 205.6084            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [461/500] Batch 804/804 Using Time: 205.5526            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [462/500] Batch 804/804 Using Time: 205.7687            Loss: 0.0009


999it [02:38,  6.29it/s]


Epoch [463/500] Batch 804/804 Using Time: 205.7500            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [464/500] Batch 804/804 Using Time: 205.4485            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [465/500] Batch 804/804 Using Time: 205.5054            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [466/500] Batch 804/804 Using Time: 205.5201            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [467/500] Batch 804/804 Using Time: 205.6305            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [468/500] Batch 804/804 Using Time: 205.6912            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [469/500] Batch 804/804 Using Time: 205.5734            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [470/500] Batch 804/804 Using Time: 205.6741            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [471/500] Batch 804/804 Using Time: 205.9629            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [472/500] Batch 804/804 Using Time: 205.3689            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [473/500] Batch 804/804 Using Time: 205.0955            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [474/500] Batch 804/804 Using Time: 205.1291            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [475/500] Batch 804/804 Using Time: 205.1303            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [476/500] Batch 804/804 Using Time: 205.4316            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [477/500] Batch 804/804 Using Time: 205.3526            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [478/500] Batch 804/804 Using Time: 205.8437            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [479/500] Batch 804/804 Using Time: 205.3270            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [480/500] Batch 804/804 Using Time: 205.0126            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [481/500] Batch 804/804 Using Time: 205.2000            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [482/500] Batch 804/804 Using Time: 205.4069            Loss: 0.0008


999it [02:38,  6.32it/s]


Epoch [483/500] Batch 804/804 Using Time: 205.2428            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [484/500] Batch 804/804 Using Time: 205.4501            Loss: 0.0008


999it [02:38,  6.31it/s]


Epoch [485/500] Batch 804/804 Using Time: 205.4306            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [486/500] Batch 804/804 Using Time: 205.3453            Loss: 0.0008


999it [02:38,  6.30it/s]


Epoch [487/500] Batch 804/804 Using Time: 204.7842            Loss: 0.0008


999it [02:37,  6.32it/s]


Epoch [488/500] Batch 804/804 Using Time: 204.2873            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [489/500] Batch 804/804 Using Time: 204.1747            Loss: 0.0008


999it [02:37,  6.33it/s]


Epoch [490/500] Batch 804/804 Using Time: 204.9579            Loss: 0.0008


999it [02:39,  6.28it/s]


Epoch [491/500] Batch 804/804 Using Time: 205.6319            Loss: 0.0008


999it [02:38,  6.29it/s]


Epoch [492/500] Batch 804/804 Using Time: 205.7036            Loss: 0.0008


999it [02:38,  6.28it/s]


Epoch [493/500] Batch 804/804 Using Time: 206.1108            Loss: 0.0008


999it [02:39,  6.27it/s]


Epoch [494/500] Batch 804/804 Using Time: 205.3226            Loss: 0.0008


999it [02:39,  6.27it/s]


Epoch [495/500] Batch 804/804 Using Time: 207.8913            Loss: 0.0008


999it [02:40,  6.22it/s]


Epoch [496/500] Batch 804/804 Using Time: 208.1886            Loss: 0.0008


999it [02:40,  6.24it/s]


Epoch [497/500] Batch 804/804 Using Time: 207.2019            Loss: 0.0008


999it [02:40,  6.24it/s]


Epoch [498/500] Batch 804/804 Using Time: 207.0071            Loss: 0.0008


999it [02:39,  6.25it/s]


Epoch [499/500] Batch 804/804 Using Time: 206.9759            Loss: 0.0008


999it [02:39,  6.25it/s]


Epoch [500/500] Batch 804/804 Using Time: 206.9300            Loss: 0.0008


999it [02:40,  6.23it/s]


# Measure FID

use this implementation: https://github.com/mseitzer/pytorch-fid/tree/master

## define FID measurement function 

In [5]:
import subprocess
import torch
from torchvision.utils import save_image
import re
import os
import gc
from modules.ddpm import  Diffusion


# Create a function to run the FID script
def run_fid(real_path, gen_path):
    command = ["python", "-m", "pytorch_fid", real_path, gen_path]

    output = subprocess.run(command, capture_output=True, text=True)

    # Extract the FID score using regular expressions
    output = output.stdout

    fid_score_match = re.search(r"FID:\s+(-?\d+\.\d+)", output)
    
    if fid_score_match:
        fid_score = float(fid_score_match.group(1))
        
        return fid_score

    else:
        print("FID score not found in the output.")

def generate_images(model, sample_size=100*8, batch_size=8, device="cuda", real_path=f"dataset/{DATASET_NAME}/", gen_path=f"generated_images/{DATASET_NAME}/DCGAN"):

    diffusion = Diffusion(img_size=IMG_SIZE, device=device)
    
    # Set the model to evaluation mode
    model.eval()

    # Create a folder for generated images if it doesn't exist
    os.makedirs(gen_path, exist_ok=True)

    # Generate images
    with torch.no_grad():
        for i in range(0, sample_size, batch_size):
            
            x = diffusion.sample(model, n=batch_size).type(dtype=torch.float32)

            for j in range(batch_size):
                save_image(x[j], f"{gen_path}/{i+j}.png", normalize=True)
        
    # model = None
    # gc.collect() 

    with torch.no_grad():
        del model
        torch.cuda.empty_cache()



## Load weight of model to measure FID score

In [10]:
def load_weight( model, weight_path="", index=int):
    # Load the saved weights
    model.load_state_dict(torch.load(f'{weight_path}/{index}.pt'))
    

In [8]:
from modules.ddpm import UNet
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'logs/{DATASET_NAME}/DDPM/')

# Specify the directory path
directory_path = Path(f'weights/{DATASET_NAME}/DDPM/')

# Get the list of files ending with ".pth"
# file_list = list(directory_path.glob('*.pt'))

sample_size = 5000

first_number = 1
last_number = 500
frequency = 8  # Specify the frequency of numbers in between

numbers = [first_number] + [first_number + (last_number - first_number) // (frequency + 1) * i for i in range(1, frequency + 1)] + [last_number]

for i in numbers:

    index_weight = i

    # model = UNet().to(device)

    # load_weight(model, weight_path=directory_path, index=index_weight)

    # generate_images(model, sample_size, batch_size=64, device=f'{device}', gen_path=f'gen/{DATASET_NAME}/DDPM/{index_weight}')

    # FID measurement
    fid_score = run_fid(f"{DATASET}", f"gen/{DATASET_NAME}/DDPM/{index_weight}")
    print(f'Epoch: {index_weight} FID score: {fid_score}')
    writer.add_scalar("Metrics/FID Score", fid_score, index_weight)


Epoch: 1 FID score: 263.81401218110216
Epoch: 8 FID score: 308.22987261783453
Epoch: 15 FID score: 178.0814243691487
Epoch: 22 FID score: 292.82191242531616
Epoch: 29 FID score: 240.96353239079892
Epoch: 36 FID score: 204.16533503679705
Epoch: 43 FID score: 252.44504759055448
Epoch: 50 FID score: 219.4068309436546
Epoch: 57 FID score: 215.58892887369643
Epoch: 64 FID score: 141.97673180449073
Epoch: 71 FID score: 167.43182430718068
Epoch: 78 FID score: 232.9618261505923
Epoch: 85 FID score: 116.75266722945042
Epoch: 92 FID score: 206.89526886624148
Epoch: 99 FID score: 103.21029417948219


# subprocess

In [None]:
import subprocess

index = 1
# Define the command to be executed
command = [
    "python",
    "modules/ddpm_sample.py",
    "--sample-size",
    "900",
    "--batch-size",
    "50",
    "--gen-path",
    f"gen/{DATASET_NAME}/DDPM/{index}",
    "--weight-path",    
    f"weights/{DATASET_NAME}/DDPM/{index}.pt"
]

# Run the command and capture the output
try:
    output = subprocess.check_output(command)
    output = output.decode("utf-8")  # Decode the byte output to string (if needed)
    print(output)
except subprocess.CalledProcessError as e:
    print("Command execution failed:", e)


In [10]:
import subprocess

first_number = 1
last_number = 500
frequency = 8  # Specify the frequency of numbers in between

numbers = [first_number] + [first_number + (last_number - first_number) // (frequency + 1) * i for i in range(1, frequency + 1)] + [last_number]

# print(len(epochs))
samples_n = 5000

for index in numbers:
    gen_n = count_image_files(f"gen/{DATASET_NAME}/DDPM/{index}")

    if gen_n < samples_n:
        remain = samples_n - gen_n
        print("Epoch:",index,"Remaining",remain,"images to generate")

        command = [
            "python",
            "modules/ddpm_sample.py",
            "--sample-size",
            f"{remain}",
            "--batch-size",
            "50",
            "--gen-path",
            f"gen/{DATASET_NAME}/DDPM/{index}",
            "--weight-path",
            f"weights/{DATASET_NAME}/DDPM/{index}.pt"
        ]

        try:
            output = subprocess.check_output(command)
            output = output.decode("utf-8")  # Decode the byte output to string (if needed)
            print(output)
        except subprocess.CalledProcessError as e:
            print("Command execution failed:", e)
        
    else:
        print("Epoch:",index,"Done")

Epoch: 1 Done
Epoch: 56 Remaining 5000 images to generate


  return F.conv2d(input, weight, bias, self.stride,
898it [04:43,  3.26it/s]

# Count learnable parameter in model

In [3]:
model = UNet().to("cuda")


pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [4]:
print(pytorch_total_params)

22291587
