In [1]:
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
from matplotlib import pyplot as plt
import math

import dataparsing

device = torch.device("cuda")

# Initialize wandb and dataset

In [2]:
train_dataset = dataparsing.AugmentedDataset(dataparsing.train, transforms=True)
test_dataset = dataparsing.AugmentedDataset(dataparsing.test, transforms=True)
val_dataset = dataparsing.AugmentedDataset(dataparsing.val, transforms=True)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True, num_workers=2)

# Define Vae 

In [3]:
cluster_centers = [0, 0, 0, 0]
def get_corner_vectors(dimension):
    corners = []
    corner1 = torch.ones(dimension)
    corner2 = -torch.ones(dimension)

    corner3 = torch.empty(dimension)
    corner4 = torch.empty(dimension)
    for i in range(dimension):
        if i % 2 == 0:
            corner3[i] = 1
            corner4[i] = -1
        else:
            corner3[i] = -1
            corner4[i] = 1
    scale_factor = 5.0
    corners = [corner1 * scale_factor, corner2 * scale_factor, corner3 * scale_factor, corner4 * scale_factor]

    return [c.to(device) for c in corners]
cluster_centers = get_corner_vectors(128)

In [4]:
class SelfAttention(nn.Module):
    def __init__(self, n_heads, embd_dim, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.n_heads = n_heads
        self.in_proj = nn.Linear(embd_dim, 3 * embd_dim, bias=in_proj_bias)
        self.out_proj = nn.Linear(embd_dim, embd_dim, bias=out_proj_bias)
        self.d_heads = embd_dim // n_heads
    def forward(self, x, casual_mask=False):
        batch_size, seq_len, d_embed = x.shape
        interim_shape = (batch_size, seq_len, self.n_heads, self.d_heads)
        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        q = q.view(interim_shape)
        k = k.view(interim_shape)
        v = v.view(interim_shape)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        weight = q @ k.transpose(-1, -2)
        if casual_mask:
            mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
            weight.masked_fill_(mask, -torch.inf)
        weight /= math.sqrt(self.d_heads)
        weight = F.softmax(weight, dim=-1)
        output = weight @ v
        output = output.transpose(1, 2)
        output = output.reshape((batch_size, seq_len, d_embed))
        output = self.out_proj(output)
        return output

In [5]:
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)
    def forward(self, x):
        residual = x.clone()
        x = self.groupnorm(x)
        n, c, h, w = x.shape
        x = x.view((n, c, h * w))
        x = x.transpose(-1, -2)
        x = self.attention(x)
        x = x.transpose(-1, -2)
        x = x.view((n, c, h, w))
        x += residual
        return x

In [6]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.groupnorm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        
        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
    def forward(self, x):
        residue = x.clone()
        x = self.groupnorm1(x)
        x = F.selu(x)
        x = self.conv1(x)
        x = self.groupnorm2(x)
        x = self.conv2(x)
        return x + self.residual_layer(residue)

In [7]:
class Encoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            ResidualBlock(128, 128),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
            ResidualBlock(128, 256),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            ResidualBlock(256, 512),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
            AttentionBlock(512),
            ResidualBlock(512, 512),
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            nn.Conv2d(512, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 8, kernel_size=1, padding=0)
        )
    def forward(self, x):
        for module in self:
            if isinstance(module, nn.Conv2d) and module.stride == (2, 2):
                x = F.pad(x, (0, 1, 0, 1))
            x = module(x)
        
        return x

In [8]:
class Decoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(4, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 512),
            AttentionBlock(512),
            ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 256),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ResidualBlock(256, 128),
            nn.GroupNorm(32, 128),
            nn.SiLU(),
            nn.Conv2d(128, 3, kernel_size=3, padding=1)
        )
    def forward(self, x):
        x /= 0.18215
        for module in self:
            x = module(x)
        return x

In [9]:
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.mu = nn.Linear(512, 128)
        self.logvar = nn.Linear(512, 128)
        self.restarter = nn.Linear(128, 256)
    def forward(self, x):
        encoded = self.encoder(x)
        encoded = encoded.view(encoded.size(0), -1)
        mean = self.mu(encoded)
        log_variance = self.logvar(encoded)
        log_variance = torch.clamp(log_variance, -30, 20)
        std = torch.exp(0.5 * log_variance)
        eps = torch.randn_like(std)
        mean = mean + eps * std
        restarted = self.restarter(mean)
        restarted = restarted.view(restarted.size(0), 4, 8, 8)
        decoded = self.decoder(restarted)
        return decoded, mean, log_variance


In [10]:
vae = VAE()
vae.to(device)
optimizer = torch.optim.AdamW(vae.parameters(), lr=1e-3)

# Define Loss

In [11]:
def loss_function(mu, y, recon_x, x):
    # mu: [batch_size, latent_dim], y: [batch_size] with values in [0..L-1]
    centers = torch.stack(cluster_centers).to(mu.device)        # [L, latent_dim]
    distances = torch.cdist(mu, centers)                       # [batch_size, L]

    # invert distances to get logits; add eps to avoid div0
    total_dist = distances.sum(dim=1, keepdim=True)            # [batch_size, 1]
    logits = (total_dist - distances) / (total_dist + 1e-8)    # [batch_size, L]

    loss = F.cross_entropy(logits, y)
    preds = logits.argmax(dim=1)
    acc  = (preds == y).float().mean()
    mse_loss = F.mse_loss(recon_x, x, reduction='sum')
    loss += mse_loss * 0.00001
    return loss, acc

# Function to sample anchor, positive, negative triplets within a batch
def sample_triplets(mu, labels):
    anchors, positives, negatives = [], [], []
    device = mu.device  # Get the device of the mu tensor
    for i in range(mu.size(0)):
        anchor, label = mu[i], labels[i]
        # Ensure the index tensor is on the same device as mu
        positive_indices = (labels == label) & (torch.arange(mu.size(0), device=device) != i)
        negative_indices = labels != label
        positive = mu[positive_indices]
        negative = mu[negative_indices]

        if len(positive) > 0 and len(negative) > 0:
            pos_sample = positive[torch.randint(len(positive), (1,), device=device)]
            neg_sample = negative[torch.randint(len(negative), (1,), device=device)]
            anchors.append(anchor)
            positives.append(pos_sample.squeeze(0))
            negatives.append(neg_sample.squeeze(0))
    if len(anchors) > 0:
        return torch.stack(anchors), torch.stack(positives), torch.stack(negatives)
    else:
        return None, None, None

# Training Loop

In [12]:
for epoch in range(40):
    vae.train()
    running_loss = 0.0
    running_acc = 0.0
    total_samples = 0


    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} - Training"):
        inputs = batch[0].to(device)
        labels = batch[1].to(device) # Get labels for loss calculation
        optimizer.zero_grad()
        recon_x, mu, logvar = vae(inputs)
        # Use the fixed cluster_centers in the loss function
        loss, accuracy = loss_function(mu, labels, recon_x, inputs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        running_acc += accuracy * inputs.size(0)
        total_samples += inputs.size(0)


    avg_train_loss = running_loss / total_samples

    # Validation loop
    vae.eval()
    running_val_loss = 0.0
    running_acc_eval = 0.0
    total_val_samples = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} - Validation"):
            inputs = batch[0].to(device)
            labels = batch[1].to(device) # Get labels for loss calculation
            recon_x, mu, logvar = vae(inputs)
            # Use the fixed cluster_centers in the loss function
            loss, acc = loss_function(mu, labels, recon_x, inputs)
            running_val_loss += loss.item() * inputs.size(0)
            running_acc_eval += acc * inputs.size(0)
            total_val_samples += inputs.size(0)

    avg_val_loss = running_val_loss / total_val_samples


    # Save a sample input and its reconstruction as an image
    vae.eval()
    sample_batch = next(iter(val_loader))
    sample_input = sample_batch[0].to(device)
    with torch.no_grad():
        sample_output, _, _ = vae(sample_input)
    sample_input = sample_input.cpu().numpy()
    sample_output = sample_output.cpu().numpy()

    num_samples = min(20, sample_input.shape[0])
    fig, axs = plt.subplots(2, num_samples, figsize=(num_samples * 2, 4))
    for i in range(num_samples):
        # Input image
        axs[0, i].imshow(sample_input[i].transpose(1, 2, 0))
        axs[0, i].axis("off")
        # Reconstructed image
        axs[1, i].imshow(sample_output[i].transpose(1, 2, 0))
        axs[1, i].axis("off")
    plt.savefig(f"vae_reconstructions/epoch_{epoch+1}.png")
    plt.close(fig)

    print(f"Train Loss: {avg_train_loss}, Train Acc: {running_acc / total_samples}, Val Loss: {avg_val_loss}, Val Acc: {running_acc_eval / total_val_samples}")

    torch.save(vae, 'VAEModelCubicFit.pth')

Epoch 1 - Training: 100%|██████████| 1819/1819 [05:20<00:00,  5.68it/s]
Epoch 1 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.61it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.019793559..1.0267051].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.01352371..0.9410879].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.015502038..0.89859116].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.019507084..0.95978993].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.009365729..0.88491064].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range

Train Loss: 1.3718013155009738, Train Acc: 0.5398625135421753, Val Loss: 1.3673725184472352, Val Acc: 0.5353407263755798


Epoch 2 - Training: 100%|██████████| 1819/1819 [05:20<00:00,  5.67it/s]
Epoch 2 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.52it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.012777794..0.9806065].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.014203019..0.5395645].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0038332548..1.0283058].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.006150985..0.882353].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.006701395..1.007824].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0

Train Loss: 1.350092654703409, Train Acc: 0.581134021282196, Val Loss: 1.3501503845024412, Val Acc: 0.5793977975845337


Epoch 3 - Training: 100%|██████████| 1819/1819 [05:21<00:00,  5.66it/s]
Epoch 3 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.49it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.013369761..0.7311668].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0053419396..0.9263864].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.021733388..0.7800881].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0031787157..0.86840236].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.027555563..1.0931821].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range

Train Loss: 1.3441262072795854, Train Acc: 0.599518895149231, Val Loss: 1.3431596329775175, Val Acc: 0.5860539078712463


Epoch 4 - Training: 100%|██████████| 1819/1819 [05:20<00:00,  5.67it/s]
Epoch 4 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.53it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.028914787..0.8280055].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.011190504..0.86688894].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.04795467..0.89324653].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.020062566..0.96107095].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.012930527..0.8229044].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range

Train Loss: 1.3381365344614506, Train Acc: 0.6115463972091675, Val Loss: 1.3397066224773788, Val Acc: 0.6009508967399597


Epoch 5 - Training: 100%|██████████| 1819/1819 [05:20<00:00,  5.67it/s]
Epoch 5 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.56it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.020325687..1.0011487].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.023981832..0.5878866].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.017531224..1.0042566].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.023813091..1.0111263].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.003814824..0.80205345].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.

Train Loss: 1.3334470724649856, Train Acc: 0.6219587326049805, Val Loss: 1.3361115753745125, Val Acc: 0.5996830463409424


Epoch 6 - Training: 100%|██████████| 1819/1819 [05:20<00:00,  5.68it/s]
Epoch 6 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.56it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.002039045..0.8647778].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.02087468..0.6410583].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.01790478..0.8237612].


Train Loss: 1.3282045903156714, Train Acc: 0.6354638934135437, Val Loss: 1.3322064587884772, Val Acc: 0.6110935211181641


Epoch 7 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 7 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.61it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.011555396..0.9640763].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.013068296..0.9629792].


Train Loss: 1.3227982569068568, Train Acc: 0.6477319598197937, Val Loss: 1.3236953066948287, Val Acc: 0.6288431286811829


Epoch 8 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.69it/s]
Epoch 8 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.54it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.032979302..1.0534216].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0043025017..1.005787].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.029961102..1.0225594].


Train Loss: 1.3176549405494506, Train Acc: 0.6526460647583008, Val Loss: 1.3232243110563033, Val Acc: 0.6342313885688782


Epoch 9 - Training: 100%|██████████| 1819/1819 [05:20<00:00,  5.68it/s]
Epoch 9 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.55it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.024895616..1.2741402].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.009127207..1.306662].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0042705685..1.0647106].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.006073728..1.338737].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.014757801..1.0049006].


Train Loss: 1.3121635159266363, Train Acc: 0.6634708046913147, Val Loss: 1.3155067799775233, Val Acc: 0.6437401175498962


Epoch 10 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.69it/s]
Epoch 10 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.60it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.024719529..1.1646498].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.02358777..1.252115].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.011679336..0.73216426].


Train Loss: 1.3088364850368697, Train Acc: 0.6704810857772827, Val Loss: 1.3100061729101673, Val Acc: 0.6573692560195923


Epoch 11 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.69it/s]
Epoch 11 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.58it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.02430439..0.9623454].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0029982328..0.7880899].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.03037244..0.8969172].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0019244403..0.8900081].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.008131698..1.1062369].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range 

Train Loss: 1.3050435898639903, Train Acc: 0.67608243227005, Val Loss: 1.3210487396327895, Val Acc: 0.6247226595878601


Epoch 12 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 12 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.62it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.010359261..1.0823208].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.016589023..0.94241583].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.013372526..1.0482416].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.004692219..1.0774881].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0022822246..1.2367085].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range 

Train Loss: 1.303055031193081, Train Acc: 0.6791752576828003, Val Loss: 1.3149051372298348, Val Acc: 0.632012665271759


Epoch 13 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 13 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.66it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.016596891..1.0231122].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0042883605..0.9620465].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.02961132..1.0621701].


Train Loss: 1.3030179630194334, Train Acc: 0.6781099438667297, Val Loss: 1.3024774609579337, Val Acc: 0.6787638664245605


Epoch 14 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 14 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.69it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0040461496..0.9168049].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.018217944..0.8744639].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.005465299..1.1249925].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0004730299..0.9271643].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0026982203..0.96507275].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got r

Train Loss: 1.3001201791206176, Train Acc: 0.6874226927757263, Val Loss: 1.302420350715589, Val Acc: 0.6651347279548645


Epoch 15 - Training: 100%|██████████| 1819/1819 [05:18<00:00,  5.70it/s]
Epoch 15 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.65it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.028359015..1.0472072].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.024582699..1.1399509].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.025986299..1.1147372].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.024438553..1.0212569].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.027710602..1.0687745].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0

Train Loss: 1.2977869590123494, Train Acc: 0.6885223388671875, Val Loss: 1.2955660300851812, Val Acc: 0.6976228356361389


Epoch 16 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 16 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.60it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.021806344..1.0777278].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.02391868..1.0385299].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.02804067..1.0321066].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.017574526..1.1636724].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0147507265..1.1600041].


Train Loss: 1.2949812211531544, Train Acc: 0.6965292096138, Val Loss: 1.3029148317932515, Val Acc: 0.6645007729530334


Epoch 17 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 17 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.62it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0087580085..1.0688579].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.01920963..1.1354301].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.015854836..1.1240218].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.020792961..1.0064476].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.012388542..1.090394].


Train Loss: 1.2956144052682463, Train Acc: 0.6940549612045288, Val Loss: 1.2956167335177753, Val Acc: 0.6847860813140869


Epoch 18 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.69it/s]
Epoch 18 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.55it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.012693003..1.257998].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.024628893..1.1089541].


Train Loss: 1.293179385375321, Train Acc: 0.69924396276474, Val Loss: 1.3013622514797276, Val Acc: 0.6591125130653381


Epoch 19 - Training: 100%|██████████| 1819/1819 [05:20<00:00,  5.68it/s]
Epoch 19 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.59it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.027288213..1.0400431].


Train Loss: 1.2930468230886558, Train Acc: 0.6981787085533142, Val Loss: 1.2915120324317702, Val Acc: 0.7074484825134277


Epoch 20 - Training: 100%|██████████| 1819/1819 [05:18<00:00,  5.70it/s]
Epoch 20 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.72it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.032593235..1.0515723].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.027128942..1.0910408].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.021034546..1.041328].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.01889012..1.1051387].


Train Loss: 1.2906151928852514, Train Acc: 0.7025772929191589, Val Loss: 1.2926142225930506, Val Acc: 0.6906497478485107


Epoch 21 - Training: 100%|██████████| 1819/1819 [05:18<00:00,  5.70it/s]
Epoch 21 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.63it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.024352878..1.0608157].


Train Loss: 1.291559096693583, Train Acc: 0.6988316178321838, Val Loss: 1.3178574593055834, Val Acc: 0.6291600465774536


Epoch 22 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 22 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.59it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.015895873..1.0167497].


Train Loss: 1.2899359806460613, Train Acc: 0.7059793472290039, Val Loss: 1.288505460985489, Val Acc: 0.7055467367172241


Epoch 23 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.69it/s]
Epoch 23 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.65it/s]


Train Loss: 1.2884615975638845, Train Acc: 0.706804096698761, Val Loss: 1.299854826624911, Val Acc: 0.6524564027786255


Epoch 24 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.69it/s]
Epoch 24 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.65it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.031568117..1.017532].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.02839154..1.0066992].


Train Loss: 1.287456279505569, Train Acc: 0.7077662944793701, Val Loss: 1.2902550299080728, Val Acc: 0.6976228356361389


Epoch 25 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 25 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.60it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0034249425..0.9342385].


Train Loss: 1.2871981182786607, Train Acc: 0.7064604759216309, Val Loss: 1.2950851180851932, Val Acc: 0.678288459777832


Epoch 26 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 26 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.58it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.021980248..1.0011039].


Train Loss: 1.2853202187974018, Train Acc: 0.7113058567047119, Val Loss: 1.3021159743733717, Val Acc: 0.6385102868080139


Epoch 27 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.70it/s]
Epoch 27 - Validation: 100%|██████████| 395/395 [00:33<00:00, 11.64it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.008517966..1.0241716].


Train Loss: 1.286870535624396, Train Acc: 0.7070790529251099, Val Loss: 1.2913737291012626, Val Acc: 0.689540445804596


Epoch 28 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.69it/s]
Epoch 28 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.56it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.026625477..1.0342276].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.017549843..0.90348685].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.00845851..0.6320876].


Train Loss: 1.2860530628289553, Train Acc: 0.7093126773834229, Val Loss: 1.2992936889645414, Val Acc: 0.6510301232337952


Epoch 29 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.69it/s]
Epoch 29 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.52it/s]


Train Loss: 1.28415377017149, Train Acc: 0.7146391868591309, Val Loss: 1.2865140529895547, Val Acc: 0.7099841833114624


Epoch 30 - Training: 100%|██████████| 1819/1819 [05:19<00:00,  5.69it/s]
Epoch 30 - Validation: 100%|██████████| 395/395 [00:34<00:00, 11.60it/s]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.03312549..0.9654236].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0036900267..1.0147707].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.010211468..1.0180722].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.010369092..0.93442184].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0061763376..0.6185344].


Train Loss: 1.2818181074607824, Train Acc: 0.722542941570282, Val Loss: 1.2880555475948352, Val Acc: 0.7028526067733765


Epoch 31 - Training:  62%|██████▏   | 1125/1819 [03:21<02:04,  5.58it/s]


KeyboardInterrupt: 