In [17]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

# Hyperparameters
img_size = 32      # images are resized to img_size x img_size by prepare_dataset
batch_size = 512   # images per optimizer step (DataLoader batch size)
max_steps = 1_000  # diffusion timesteps: DDPMScheduler(num_train_timesteps=...) and t sampling range
lr = 0.001         # AdamW learning rate
epochs = 1_000     # upper bound of the epoch loop in the training script

class CustomModel(nn.Module):
    """Compact U-Net mapping an image tensor to a same-sized output tensor.

    Architecture:
      * encoder: one double 3x3-conv + ReLU stage per entry of ``features``;
        each stage's output is kept as a skip connection, then 2x max-pooled;
      * bottleneck: double 3x3-conv that doubles the deepest channel count;
      * decoder: per level, a stride-2 transposed conv, element-wise *addition*
        of the matching skip connection, and a double 3x3-conv stage;
      * a final 1x1 conv to ``out_channels``.

    Height and width of the input must be divisible by ``2 ** len(features)``
    so that every upsampled tensor matches its skip connection.

    Timestep conditioning (optional): a DDPM noise predictor normally needs to
    know the diffusion step ``t``. Pass ``time_emb_dim`` to enable it. ``t`` is
    then encoded with a sinusoidal embedding, projected by a small MLP, and
    added to the bottleneck activations. With the default
    ``time_emb_dim=None``, the module is the plain unconditioned network. Its
    parameters, initialization order and ``state_dict`` keys are unchanged,
    so existing checkpoints still load.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output tensor.
        features: channel width of each encoder level; defaults to
            ``[32, 64, 128]``. Must be non-empty.
        time_emb_dim: even, positive size of the sinusoidal timestep
            embedding, or ``None`` (default) for no timestep conditioning.

    Raises:
        ValueError: if ``features`` is empty or ``time_emb_dim`` is not a
            positive even integer.
    """

    def __init__(self, in_channels=1, out_channels=1, features=None, time_emb_dim=None):
        super().__init__()
        if features is None:
            features = [32, 64, 128]
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")
        if time_emb_dim is not None and (time_emb_dim <= 0 or time_emb_dim % 2):
            raise ValueError(f"time_emb_dim must be a positive even integer, got {time_emb_dim}")
        bottleneck_channels = features[-1] * 2

        # Encoder: level i maps the previous level's width (or in_channels) to features[i].
        self.encoders = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels if i == 0 else features[i - 1], features[i], kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(features[i], features[i], kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )
            for i in range(len(features))
        ])

        # Shared pooling layer: halves height and width after every encoder level.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck at the lowest resolution; doubles the deepest channel count.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features[-1], bottleneck_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Decoder stages, deepest level first; each keeps the width features[i].
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(features[i], features[i], kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(features[i], features[i], kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )
            for i in reversed(range(len(features)))
        ])

        # Transposed convs: double height/width and reduce channels to features[i],
        # so the result can be added to the skip connection of the same level.
        self.upsamples = nn.ModuleList([
            nn.ConvTranspose2d(bottleneck_channels if i == len(features) - 1 else features[i + 1], features[i], kernel_size=2, stride=2)
            for i in reversed(range(len(features)))
        ])

        # Final 1x1 convolution to the requested number of output channels.
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        # Optional timestep MLP. It is created last, and only when requested, so
        # the default model's parameter order and state_dict stay unchanged.
        self.time_emb_dim = time_emb_dim
        if time_emb_dim is not None:
            self.time_mlp = nn.Sequential(
                nn.Linear(time_emb_dim, bottleneck_channels),
                nn.ReLU(),
                nn.Linear(bottleneck_channels, bottleneck_channels),
            )
        else:
            self.time_mlp = None

    def _embed_timesteps(self, t, x):
        """Return a ``(batch, bottleneck_channels)`` projection of timesteps ``t``.

        ``t`` may be a Python int, a 0-d tensor (one step shared by the whole
        batch, as when iterating ``scheduler.timesteps`` during sampling) or a
        1-d tensor with one step per batch element.
        """
        t = torch.as_tensor(t, device=x.device)
        if t.ndim == 0:
            t = t.expand(x.shape[0])
        half = self.time_emb_dim // 2
        # Geometric frequency ladder of the standard sinusoidal position encoding.
        freqs = 10000.0 ** (-torch.arange(half, device=x.device, dtype=torch.float32) / half)
        angles = t.to(torch.float32)[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        return self.time_mlp(emb.to(x.dtype))

    def forward(self, x, t=None):
        """Run the U-Net on ``x``.

        Args:
            x: input tensor whose last two dims (height, width) are divisible
                by ``2 ** len(features)``.
            t: diffusion timestep(s). Required when the model was built with
                ``time_emb_dim`` and must be omitted otherwise.

        Returns:
            Tensor with ``out_channels`` channels and the same height/width as ``x``.

        Raises:
            ValueError: on indivisible spatial size, or when ``t`` is missing
                for a conditioned model / given to an unconditioned one.
        """
        factor = 2 ** len(self.encoders)
        if x.shape[-2] % factor or x.shape[-1] % factor:
            raise ValueError(
                f"input height and width must be divisible by {factor}, got {tuple(x.shape[-2:])}"
            )
        if self.time_mlp is None and t is not None:
            raise ValueError("t was given but the model was built without time_emb_dim")
        if self.time_mlp is not None and t is None:
            raise ValueError("this model was built with time_emb_dim; forward() requires t")

        skip_connections = []

        # Encoder: keep each level's full-resolution output for the decoder.
        for encoder in self.encoders:
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)

        # Bottleneck (+ timestep embedding broadcast over height and width).
        x = self.bottleneck(x)
        if self.time_mlp is not None:
            x = x + self._embed_timesteps(t, x)[:, :, None, None]

        # Decoder: deepest skip connection pairs with the first upsample stage.
        for upsample, decoder, skip in zip(self.upsamples, self.decoders, reversed(skip_connections)):
            x = upsample(x)
            x = x + skip
            x = decoder(x)

        return self.final_conv(x)

    
def prepare_dataset(batch_size):
    """Build a shuffled DataLoader over the FashionMNIST training split.

    The pipeline does three things:
      * resizes each image to ``img_size`` x ``img_size`` (module-level setting);
      * converts it to a tensor in [0, 1] with ``ToTensor``;
      * maps that to [-1, 1] with ``Normalize((0.5,), (0.5,))``.

    The [-1, 1] range matches the Gaussian-noise scale the diffusion process
    adds. It also matches diffusers' ``DDPMScheduler`` default
    (``clip_sample=True``), which clips its predicted clean sample to [-1, 1]
    during sampling. Training on [0, 1] data would disagree with that clipping.

    The dataset is downloaded into ``./data`` if it is not already present.

    Args:
        batch_size: number of images per batch.

    Returns:
        DataLoader yielding ``(images, labels)`` batches with shuffling enabled.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 : [0, 1] -> [-1, 1] for the single grayscale channel.
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Load FashionMNIST training split
    dataset = datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

def show_images(images, rows=2, cols=10):
    """Display up to ``rows * cols`` single-channel images in a grid.

    Each entry of ``images`` must be a tensor. It is moved to CPU, converted
    to NumPy, squeezed and drawn with a gray colormap.

    Two edge cases are handled:
      * grid cells beyond ``len(images)`` are left blank instead of raising
        IndexError;
      * a 1x1 grid works because ``squeeze=False`` always returns a 2-D axes
        array.

    Args:
        images: indexable sequence of image tensors.
        rows: number of grid rows.
        cols: number of grid columns.
    """
    _, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < len(images):
            ax.imshow(images[i].cpu().numpy().squeeze(), cmap='gray')
        ax.axis('off')
    plt.show()

def diffusion_process_and_show_images(scheduler, model):
    """Generate 20 samples with the full reverse-diffusion loop and plot them.

    The loop works as follows:
      * start from standard-normal noise of shape ``(20, 1, img_size, img_size)``;
      * visit every entry of ``scheduler.timesteps`` in order;
      * at each step, pass the model's output for the current samples to
        ``scheduler.step`` and use its ``prev_sample`` as the next input.

    The first channel of each final sample is shown via ``show_images``.

    The model runs in eval mode without gradients. Its previous train/eval
    mode is restored afterwards, even if sampling is interrupted.

    Args:
        scheduler: diffusion scheduler exposing ``timesteps`` and ``step``
            (a diffusers ``DDPMScheduler`` in this file).
        model: noise-prediction network, called as ``model(samples)``.
    """
    # Allocate noise where the model's weights live. Fall back to the
    # module-level ``device`` only for a model without parameters.
    first_param = next(model.parameters(), None)
    sample_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            samples = torch.randn((20, 1, img_size, img_size), device=sample_device)
            # Iterate the timesteps directly (no enumerate) so tqdm knows the total.
            for t in tqdm(scheduler.timesteps):
                samples = scheduler.step(model(samples), t, samples).prev_sample
    finally:
        model.train(was_training)
    show_images([sample[0] for sample in samples])

def save_checkpoint(epoch, model, optimizer, loss_avg, checkpoint_dir="data/FashionMNIST"):
    """Save model and optimizer state for ``epoch``.

    Writes ``<checkpoint_dir>/checkpoint_epoch_<epoch>.pth``, a dict with keys
    ``'epoch'``, ``'model_state_dict'``, ``'optimizer_state_dict'`` and
    ``'loss'``. The directory is created first if missing, so ``torch.save``
    does not fail on a fresh checkout.

    Args:
        epoch: index of the epoch that has just finished.
        model: module whose ``state_dict`` is stored.
        optimizer: optimizer whose ``state_dict`` is stored.
        loss_avg: average training loss of the epoch.
        checkpoint_dir: destination directory (default keeps the original location).

    Returns:
        The path of the written checkpoint file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_avg,
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")
    return checkpoint_path
        
def load_checkpoint(checkpoint_path, model, optimizer):
    """Restore model and optimizer state from a file written by ``save_checkpoint``.

    Tensors are mapped onto the device of ``model``'s parameters. Only a
    parameter-less model falls back to the module-level ``device``, so the
    function no longer depends on that global being defined first.

    NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
    from a trusted source.

    Args:
        checkpoint_path: path to the ``.pth`` file.
        model: module to load ``'model_state_dict'`` into.
        optimizer: optimizer to load ``'optimizer_state_dict'`` into.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint. The training script
        saves after an epoch has finished, so resume from ``epoch + 1``.
    """
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else device
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded: Epoch {epoch}, Loss: {loss}")
    return epoch, loss

def plot_losses(losses):
    """Draw the loss history (one value per epoch) on the current axes and show it."""
    ax = plt.gca()
    ax.plot(losses)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    plt.show()

# Select device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Prepare dataset
dataloader = prepare_dataset(batch_size)

# Initialize model
#model = UNet2DModel(
#    sample_size=img_size,
#    in_channels=1,
#    out_channels=1,
#    layers_per_block=2,
#    block_out_channels=(32, 64, 128),
#    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
#    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
#).to(device)
# Model Initialization
model = CustomModel(in_channels=1, out_channels=1).to(device)

# Scheduler for Diffusion
scheduler = DDPMScheduler(num_train_timesteps=max_steps)

# Optimizer
optimizer = AdamW(model.parameters(), lr=lr)

start_epoch = 0
# To resume, load a checkpoint. The stored epoch has already been completed
# (checkpoints are saved at the end of each epoch), so continue with the next one.
# checkpoint_path = "data/FashionMNIST/checkpoint_epoch_274.pth"
# if os.path.exists(checkpoint_path):
#     last_epoch, _ = load_checkpoint(checkpoint_path, model, optimizer)
#     start_epoch = last_epoch + 1

# Training loop
losses = []
for epoch in range(start_epoch, epochs):
    model.train()
    loss_sum = 0.0
    num_batches = 0
    for images, _ in tqdm(dataloader):  # class labels are unused (unconditional model)
        optimizer.zero_grad()
        x = images.to(device)

        # Random diffusion timestep for each image in the batch
        t = torch.randint(0, max_steps, (len(x),), device=device)

        # Forward diffusion: corrupt the clean images with Gaussian noise at step t
        noise = torch.randn_like(x)
        noisy_images = scheduler.add_noise(x, noise, t)

        # The model is trained to predict the added noise
        noise_pred = model(noisy_images)

        # Back Propagation
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        num_batches += 1

    loss_avg = loss_sum / num_batches
    losses.append(loss_avg)
    print(f'Epoch {epoch} | Loss: {loss_avg}')

    save_checkpoint(epoch, model, optimizer, loss_avg)
    plot_losses(losses)

    # Generate samples from pure noise with the current weights and visualize them
    diffusion_process_and_show_images(scheduler, model)

# Save final weights. CustomModel is a plain nn.Module and has no
# ``save_pretrained`` method (the original call raised AttributeError after
# training), so persist the state dict with torch.save instead.
final_model_path = os.path.join("data", "FashionMNIST", "model_final.pth")
os.makedirs(os.path.dirname(final_model_path), exist_ok=True)
torch.save(model.state_dict(), final_model_path)
print(f"Model saved at {final_model_path}")


cuda


  2%|▏         | 2/118 [00:00<00:06, 16.65it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


  3%|▎         | 4/118 [00:00<00:06, 16.80it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


  5%|▌         | 6/118 [00:00<00:06, 16.73it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


  7%|▋         | 8/118 [00:00<00:06, 16.77it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


  8%|▊         | 10/118 [00:00<00:06, 16.81it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 10%|█         | 12/118 [00:00<00:06, 16.87it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 12%|█▏        | 14/118 [00:00<00:06, 16.89it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 14%|█▎        | 16/118 [00:00<00:06, 16.91it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 15%|█▌        | 18/118 [00:01<00:05, 16.92it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 17%|█▋        | 20/118 [00:01<00:05, 16.93it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 19%|█▊        | 22/118 [00:01<00:05, 16.93it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 20%|██        | 24/118 [00:01<00:05, 16.80it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 22%|██▏       | 26/118 [00:01<00:05, 16.88it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 24%|██▎       | 28/118 [00:01<00:05, 16.91it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 25%|██▌       | 30/118 [00:01<00:05, 16.91it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 29%|██▉       | 34/118 [00:02<00:05, 14.28it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 32%|███▏      | 38/118 [00:02<00:05, 15.53it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 36%|███▌      | 42/118 [00:02<00:04, 16.26it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 39%|███▉      | 46/118 [00:02<00:04, 16.61it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 42%|████▏     | 50/118 [00:03<00:04, 16.81it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 46%|████▌     | 54/118 [00:03<00:03, 16.85it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 49%|████▉     | 58/118 [00:03<00:03, 16.82it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 53%|█████▎    | 62/118 [00:03<00:03, 16.76it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 56%|█████▌    | 66/118 [00:04<00:03, 16.84it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 59%|█████▉    | 70/118 [00:04<00:02, 16.87it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 63%|██████▎   | 74/118 [00:04<00:02, 16.87it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 66%|██████▌   | 78/118 [00:04<00:02, 16.94it/s]

torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])
torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


 67%|██████▋   | 79/118 [00:04<00:02, 16.40it/s]


torch.Size([512, 256, 4, 4])
torch.Size([512, 128, 8, 8]) torch.Size([512, 128, 8, 8])
torch.Size([512, 64, 16, 16]) torch.Size([512, 64, 16, 16])
torch.Size([512, 32, 32, 32]) torch.Size([512, 32, 32, 32])


KeyboardInterrupt: 