In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from abc import abstractmethod

In [None]:
from typing import List, Callable, Union, Any, TypeVar, Tuple
# `Tensor` is used purely as a type annotation throughout the notebook.
# Import the real tensor class instead of declaring a TypeVar with a
# mismatched name, so annotations resolve to `torch.Tensor`.
from torch import Tensor

In [None]:
class BaseVAE(nn.Module):
    """Interface skeleton shared by the VAE models in this notebook.

    The optional hooks (`encode`, `decode`, `sample`, `generate`) raise
    ``NotImplementedError`` until a subclass overrides them. `forward` and
    `loss_function` are tagged with ``@abstractmethod`` and return ``None``
    here; subclasses are expected to supply real implementations.
    """

    def __init__(self) -> None:
        super().__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        """Map an input batch to its latent representation(s)."""
        raise NotImplementedError()

    def decode(self, input: Tensor) -> Any:
        """Map latent codes back to the data space."""
        raise NotImplementedError()

    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
        """Draw `batch_size` new samples on `current_device`."""
        raise NotImplementedError()

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Produce a reconstruction of `x`."""
        raise NotImplementedError()

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        """Run the model; placeholder that returns ``None``."""
        return None

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        """Compute the training loss; placeholder that returns ``None``."""
        return None

In [None]:
class VectorQuantizer(nn.Module):
    """
    Vector-quantization bottleneck: snaps every D-dimensional latent vector
    to its nearest codebook entry (squared L2 distance).

    The codebook is learned either by the embedding loss (``use_ema=False``)
    or by exponential-moving-average updates (``use_ema=True``).

    Reference:
    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py

    :param num_embeddings: (int) Number of codebook entries K
    :param embedding_dim: (int) Dimensionality D of each codebook entry
    :param beta: (float) Weight of the commitment loss
    :param use_ema: (bool) Update the codebook by EMA instead of by gradient
    :param decay: (float) EMA decay factor (only used when ``use_ema``)
    :param epsilon: (float) Laplace-smoothing constant for the EMA cluster sizes
    """
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 beta: float = 0.25,
                 use_ema: bool = False,
                 decay: float = 0.99,
                 epsilon: float = 1e-5):
        super(VectorQuantizer, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)

        self.use_ema = use_ema

        if self.use_ema:
            self.embedding.weight.data.normal_()
            self.register_buffer('_ema_cluster_size', torch.zeros(self.num_embeddings))
            # Kept as a Parameter (not a buffer) so `parameters()` and the
            # state_dict layout stay as they were; it never receives a
            # gradient because it is only touched inside `_ema_update`.
            self.ema_w = nn.Parameter(torch.Tensor(self.num_embeddings, self.embedding_dim))
            self.ema_w.data.normal_()

            self.decay = decay
            self.epsilon = epsilon
        else:
            self.embedding.weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)

    @torch.no_grad()
    def _ema_update(self, encoding_one_hot: Tensor, flat_latents: Tensor) -> None:
        """Update cluster sizes, EMA sums and the codebook in place."""
        self._ema_cluster_size.mul_(self.decay).add_(
            torch.sum(encoding_one_hot, 0), alpha=1 - self.decay)

        # Laplace smoothing keeps unused codes from dividing by zero below.
        n = torch.sum(self._ema_cluster_size)
        self._ema_cluster_size.copy_(
            (self._ema_cluster_size + self.epsilon)
            / (n + self.num_embeddings * self.epsilon) * n
        )

        dw = torch.matmul(encoding_one_hot.t(), flat_latents)
        self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)

        # In-place copy keeps the same Parameter object alive, so optimizers
        # and wrappers that hold a reference to it are not left with a stale one.
        self.embedding.weight.copy_(self.ema_w / self._ema_cluster_size.unsqueeze(1))

    def forward(self, latents: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Quantize a batch of latent feature maps.

        :param latents: (Tensor) [B x D x H x W] with D == embedding_dim
        :return: (vq_loss, perplexity, quantized latents [B x D x H x W],
                  one-hot code assignments [BHW x K])
        """
        latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
        latents_shape = latents.shape
        flat_latents = latents.view(-1, self.embedding_dim)  # [BHW x D]

        # The distances only feed argmin (non-differentiable), so no autograd
        # graph is needed for them.
        with torch.no_grad():
            # Compute L2 distance between latents and embedding weights
            dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
                   torch.sum(self.embedding.weight ** 2, dim=1) - \
                   2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]

            # Get the encoding that has the min distance
            encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]

        # Convert to one-hot encodings (same dtype as the latents so the
        # matmuls below do not mix dtypes)
        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.num_embeddings,
                                       device=latents.device, dtype=flat_latents.dtype)
        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]

        # Use EMA to update the embedding vectors -- training only, so that
        # evaluation / inference never changes the codebook.
        if self.use_ema and self.training:
            self._ema_update(encoding_one_hot, flat_latents)

        # Quantize the latents
        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]

        # Compute the VQ Losses
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)

        if self.use_ema:
            # The codebook is moved by EMA, so only the commitment term is needed.
            vq_loss = commitment_loss * self.beta
        else:
            embedding_loss = F.mse_loss(quantized_latents, latents.detach())
            vq_loss = commitment_loss * self.beta + embedding_loss

        # Straight-through estimator: forward uses the quantized values,
        # backward passes gradients to `latents` unchanged.
        quantized_latents = latents + (quantized_latents - latents).detach()

        # Perplexity of the empirical code distribution in this batch
        avg_probs = torch.mean(encoding_one_hot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return vq_loss, perplexity, quantized_latents.permute(0,3,1,2).contiguous(), encoding_one_hot

In [None]:
class ResidualLayer(nn.Module):
    """Residual block: ``x + Conv1x1(ReLU(Conv3x3(x)))``, both convs bias-free.

    The skip connection adds the input to the block output directly, so the
    two must have compatible shapes (in practice in_channels == out_channels).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int):
        super().__init__()
        spatial_conv = nn.Conv2d(in_channels, out_channels,
                                 kernel_size=3, padding=1, bias=False)
        pointwise_conv = nn.Conv2d(out_channels, out_channels,
                                   kernel_size=1, bias=False)
        # Indices 0 / 1 / 2 inside the Sequential are part of the state_dict keys.
        self.resblock = nn.Sequential(spatial_conv,
                                      nn.ReLU(inplace=True),
                                      pointwise_conv)

    def forward(self, input: Tensor) -> Tensor:
        residual = self.resblock(input)
        return input + residual

In [None]:
class VQVAE(BaseVAE):
    """Convolutional VQ-VAE: encoder -> vector quantizer -> decoder.

    The encoder halves the spatial resolution once per entry of
    `hidden_dims` (stride-2 convolutions); the decoder mirrors it with
    stride-2 transposed convolutions and ends in ``Tanh``.

    :param in_channels: (int) Channels of the input image; the decoder
        reconstructs the same number of channels
    :param embedding_dim: (int) Dimensionality of each codebook vector
    :param num_embeddings: (int) Number of codebook vectors
    :param use_ema: (bool) Update the codebook with EMA instead of gradients
    :param hidden_dims: (List[int]) Encoder widths; defaults to [128, 256].
        The list passed in is not modified.
    :param beta: (float) Commitment-loss weight forwarded to the quantizer
    :param img_size: (int) Stored on the instance; not used by the layers here
    """

    def __init__(self,
                 in_channels: int,
                 embedding_dim: int,
                 num_embeddings: int,
                 use_ema: bool = True,
                 hidden_dims: List = None,
                 beta: float = 0.25,
                 img_size: int = 32,
                 **kwargs) -> None:
        super(VQVAE, self).__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.img_size = img_size
        self.beta = beta

        # Work on a private copy so the caller's list is never mutated.
        hidden_dims = [128, 256] if hidden_dims is None else list(hidden_dims)
        if not hidden_dims:
            raise ValueError('hidden_dims must contain at least one layer width.')

        # `in_channels` is reused as a running variable below; remember the
        # image channel count for the decoder's output layer.
        image_channels = in_channels

        modules = []

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels,
                          kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU())
        )

        for _ in range(6):
            modules.append(ResidualLayer(in_channels, in_channels))
        modules.append(nn.LeakyReLU())

        # 1x1 projection to the codebook dimensionality
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, embedding_dim,
                          kernel_size=1, stride=1),
                nn.LeakyReLU())
        )

        self.encoder = nn.Sequential(*modules)

        self.vq_layer = VectorQuantizer(num_embeddings,
                                        embedding_dim,
                                        self.beta,
                                        use_ema)

        # Build Decoder
        modules = []
        modules.append(
            nn.Sequential(
                nn.Conv2d(embedding_dim,
                          hidden_dims[-1],
                          kernel_size=3,
                          stride=1,
                          padding=1),
                nn.LeakyReLU())
        )

        for _ in range(6):
            modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))

        modules.append(nn.LeakyReLU())

        # Widest -> narrowest, mirroring the encoder.
        decoder_dims = hidden_dims[::-1]

        for i in range(len(decoder_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(decoder_dims[i],
                                       decoder_dims[i + 1],
                                       kernel_size=4,
                                       stride=2,
                                       padding=1),
                    nn.LeakyReLU())
            )

        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(decoder_dims[-1],
                                   out_channels=image_channels,
                                   kernel_size=4,
                                   stride=2, padding=1),
                nn.Tanh()))

        self.decoder = nn.Sequential(*modules)

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Single-element list holding the (pre-quantization) latents
        """
        result = self.encoder(input)
        return [result]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        result = self.decoder(z)
        return result

    def vq(self, encoding: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Quantizes encoder output with the codebook.
        :param encoding: (Tensor) [B x D x H x W]
        :return: (vq_loss, perplexity, quantized latents, one-hot assignments)
        """
        return self.vq_layer(encoding)

    def forward(self, inputs: Tensor, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Full autoencoding pass.
        :param inputs: (Tensor) [B x C x H x W]
        :return: (reconstruction, vq_loss, perplexity)
        """
        encoding = self.encode(inputs)[0]
        vqloss, perplexity, quantized_inputs, encoding_one_hot = self.vq(encoding)
        reconstructed = self.decode(quantized_inputs)
        return reconstructed, vqloss, perplexity

    def sample(self,
               num_samples: int,
               current_device: Union[int, str], **kwargs) -> Tensor:
        """Not implemented for this model; raises a ``Warning``."""
        raise Warning('VQVAE sampler is not implemented.')

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [None]:
import torchvision
import torchvision.transforms as transforms

image_size = 32

# Training pipeline: pad-and-crop + horizontal-flip augmentation, then
# tensor conversion and per-channel standardization.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.Resize(size=image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: no augmentation, same standardization.
transform_test = transforms.Compose([
    transforms.Resize(size=image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

batch_size = 64

# CIFAR-10 train split (downloaded into ./data on first use) and its loader
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# CIFAR-10 test split and its (unshuffled) loader
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    dataset=testset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Function to unnormalize and display an image
def imshow(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Undo per-channel normalization and display a [C x H x W] image tensor.

    The defaults match the ``transforms.Normalize`` statistics used by the
    data pipelines in this notebook; pass ``mean=(0.5,)*3, std=(0.5,)*3`` to
    get the former ``img / 2 + 0.5`` behaviour.

    :param img: (Tensor) normalized image or image grid, [C x H x W]
    :param mean: per-channel mean that was subtracted during normalization
    :param std: per-channel std that was divided by during normalization
    """
    # Detach / move to CPU so tensors from the GPU or the autograd graph work too.
    img = img.detach().cpu()
    channel_mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    channel_std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    # Invert (x - mean) / std, then clamp to the displayable [0, 1] range.
    img = (img * channel_std + channel_mean).clamp(0, 1)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

# Define a transform to reverse the normalization applied by the data
# pipelines: x = x_norm * std + mean, which as a Normalize call is
# Normalize(mean=-mean/std, std=1/std).
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)
reverse_transform = transforms.Compose([
    transforms.Normalize([-m / s for m, s in zip(_norm_mean, _norm_std)],
                         [1.0 / s for s in _norm_std]),
    # Keep values inside [0, 1] so the float -> 8-bit conversion cannot wrap.
    transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
    transforms.ToPILImage()
])

# Get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Show images
imshow(torchvision.utils.make_grid(images))

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Variance over every value of the raw training array after dividing by 255
# (presumably 8-bit pixels scaled to [0, 1] -- TODO confirm).
# NOTE(review): this value is not referenced anywhere else in the visible notebook.
train_data_variance = np.var(trainset.data / 255.0)

In [None]:
import wandb

In [None]:
# Hyperparameters of the VQ-VAE run; passed to Weights & Biases as the
# run configuration.
vqvae_cfg = {
    'epochs':25,
    'in_channels':3,
    'num_embed':512,
    'embed_dim':64,
    'lr':3e-4,
    'use_ema':True
}

# Start a W&B run in project "bagel", grouped under "vqvae".
wandb.init(project='bagel',group='vqvae',config=vqvae_cfg)

In [None]:
from tqdm import tqdm
import torch.optim as optim

# Instantiate the VQ-VAE model from the configuration that is logged to
# W&B, so the logged hyperparameters and the trained model cannot drift apart.
model = VQVAE(in_channels=vqvae_cfg['in_channels'],
              num_embeddings=vqvae_cfg['num_embed'],
              embedding_dim=vqvae_cfg['embed_dim'],
              use_ema=vqvae_cfg['use_ema']).to(device)
model.train()

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=vqvae_cfg['lr'])

# Per-epoch means of each tracked quantity
train_losses = []
recon_losses = []
vq_losses = []
ppls = []
LOSS = float('inf')  # lowest mean epoch loss seen so far

# Training loop
num_epochs = vqvae_cfg['epochs']

for epoch in tqdm(range(num_epochs)):
    loss_sum = recon_sum = vq_sum = ppl_sum = 0.0
    num_batches = 0

    for inputs, _ in trainloader:
        inputs = inputs.to(device)
        optimizer.zero_grad()

        # Forward pass
        reconstructed, vq_loss, perplexity = model(inputs)

        # Compute the loss
        # NOTE(review): the decoder ends in Tanh (outputs in [-1, 1]) while
        # the inputs are standardized with per-channel mean/std and can fall
        # outside that range -- confirm this target range is intended.
        recon_loss = criterion(reconstructed, inputs)
        loss = recon_loss + vq_loss

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        recon_sum += recon_loss.item()
        vq_sum += vq_loss.item()
        ppl_sum += perplexity.item()
        num_batches += 1

    # Report the mean over the epoch rather than whatever the last batch gave.
    num_batches = max(num_batches, 1)
    epoch_loss = loss_sum / num_batches
    epoch_ppl = ppl_sum / num_batches
    train_losses.append(epoch_loss)
    recon_losses.append(recon_sum / num_batches)
    vq_losses.append(vq_sum / num_batches)
    ppls.append(epoch_ppl)

    # Save the trained model only when the epoch loss improved; the loss is
    # stored as a plain float so the checkpoint holds no autograd graph.
    if epoch_loss < LOSS:
        LOSS = epoch_loss
        torch.save({'state_dict': model.state_dict(), 'epoch': epoch, 'loss': epoch_loss}, 'vq_vae_model_ema.pth')
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss}')

    wandb.log({'vqvae_loss': epoch_loss,
               'vqvae_recon_loss': recon_losses[-1],
               'vqvae_vq_loss': vq_losses[-1],
               'vqvae_perplexity': epoch_ppl,
               'epoch': epoch})

In [None]:
# Use the trained VQ-VAE as an image tokenizer
# You can use the encoder part of the model to obtain codes for input images
# Example:
# 
# model.eval()

In [None]:
# vqvae_model.eval()
# with torch.no_grad():
#     for data in trainloader:
#         inputs, _ = data
#         inputs = inputs.to(device)
#         encoded_inputs = vqvae_model.encode(inputs)[0]
#         for param in vqvae_model.parameters():
#             print(param.requires_grad)
#         _, _, quantized_inputs, _ = vqvae_model.vq(encoded_inputs)
#     #     flattened_codes = quantized_inputs.view(inputs.size(0), -1, inputs.dim)
#         flattened_codes = quantized_inputs.flatten(2).transpose(1, 2)
#         vq_dim = flattened_codes.size(1)
#         project_vs = nn.Linear(vq_dim, edim).to(device)
#         flattened_codes = project_vs(flattened_codes)

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal `patch_size`
    produces one `embed_dim`-vector per patch; the patch grid is then
    flattened into a token sequence and layer-normalized.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side ** 2

        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        patch_grid = self.proj(x)                        # [B, E, H/P, W/P]
        tokens = patch_grid.flatten(2).transpose(1, 2)   # [B, N, E]
        return self.norm(tokens)

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    :param embed_dim: token width; must be divisible by `num_heads`
    :param num_heads: number of attention heads
    :raises ValueError: if `embed_dim` cannot be split evenly across the heads
    """

    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        # Fail at construction time instead of with an opaque reshape error
        # on the first forward pass.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f'embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads}).')
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim ** -0.5
        # One projection produces queries, keys and values together.
        self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3)
        # Optional per-head LayerNorm on q / k and an activation hook; both
        # switches are off, so the corresponding modules are identities.
        self.qk_norm = False
        self.use_activation = False
        self.activation = nn.ReLU() if self.use_activation else nn.Identity()
        self.q_norm = nn.LayerNorm(self.head_dim) if self.qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if self.qk_norm else nn.Identity()
        self.attn_dropout = nn.Dropout(0.1)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.proj_dropout = nn.Dropout(0)  # defined but not applied in forward

    def forward(self, x):
        """
        :param x: token sequence [B x N x E]
        :return: attended tokens [B x N x E]
        """
        batch_si, seq_len, emb_dim = x.shape
        # [B, N, 3E] -> [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(batch_si, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        q = q * self.scale
        attention = q @ k.transpose(-2, -1)      # [B, heads, N, N]
        attention = attention.softmax(dim=-1)
        attention = self.attn_dropout(attention)

        z = attention @ v                        # [B, heads, N, head_dim]
        z = z.transpose(1, 2).reshape(batch_si, seq_len, emb_dim)  # merge heads
        z = self.proj(z)
        return z

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Two residual sub-layers, each applied to a LayerNorm-ed copy of its
    input: multi-head self-attention, then a GELU MLP with dropout between
    its two linear layers.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        super().__init__()
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads)
        self.attention_norm = nn.LayerNorm(embed_dim)

        feed_forward = [
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
        ]
        self.mlp = nn.Sequential(*feed_forward)
        self.mlp_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Attention sub-layer with skip connection
        x = x + self.attention(self.attention_norm(x))
        # Feed-forward sub-layer with skip connection
        x = x + self.mlp(self.mlp_norm(x))
        return x

In [None]:
class VitWithVQVAE(nn.Module):
    """Vision-transformer classifier with a selectable tokenizer.

    Modes:
      * ``'patch'`` -- standard ViT patch embedding.
      * ``'vqvae'`` -- tokens are the quantized latents of a (non-trained-here)
        VQ-VAE, linearly projected from `vq_dim` to `embed_dim`.
      * ``'conv'``  -- a strided convolution produces features that are passed
        through the VQ-VAE quantizer and then linearly projected.

    :param vq_dim: channel width of the VQ-VAE latents (required for 'vqvae')
    :param vqvae: the VQ-VAE to use for 'vqvae' / 'conv' modes. When omitted,
        the notebook-level global ``vqvae_model`` is used, as before.
    :raises ValueError: if `mode` is not one of the three supported values
    """

    _MODES = ('patch', 'vqvae', 'conv')

    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout, mode, vq_dim=None, vqvae=None):
        super(VitWithVQVAE, self).__init__()
        if mode not in self._MODES:
            raise ValueError(f'mode must be one of {self._MODES}, got {mode!r}.')
        self.mode = mode

        if self.mode == 'patch':
            self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
            self.embed_len = self.patch_embed.num_patches + 1
        elif self.mode == 'vqvae':
            self.conv_dilate = 1
            self.conv_kernel = 4
            self.conv_stride = 4
            # Registered as a submodule, so it follows .to(device) and its
            # weights appear in this model's parameters()/state_dict.
            self.vqvae = vqvae if vqvae is not None else vqvae_model
            self.vq_dim = vq_dim
            self.project_vs = nn.Linear(vq_dim, embed_dim)
            self.embed_len = self._max_seq_len(image_size, self.conv_kernel,
                                               self.conv_stride, self.conv_dilate)
        else:  # 'conv'
            self.conv_dilate = 1
            self.conv_kernel = 5
            self.conv_stride = 3
            self.conv2d = nn.Conv2d(in_channels,
                                    embed_dim * 1,
                                    5,
                                    groups=1,
                                    stride=3,
                                    dilation=1)
            self.proj_feats = nn.Linear(embed_dim * 1, embed_dim)
            self.feat_norm = nn.LayerNorm(embed_dim)
            self.embed_len = self._max_seq_len(image_size, self.conv_kernel,
                                               self.conv_stride, self.conv_dilate)
            # Held inside a tuple so it is NOT registered as a submodule
            # (this mode never owned the VQ-VAE's weights).
            self._conv_vqvae = (vqvae,)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_len, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout) for i in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.cls_head = nn.Sequential(nn.Linear(embed_dim, embed_dim//2),
                                nn.GELU(),
                                nn.Dropout(dropout),
                                nn.Linear(embed_dim // 2, num_classes),
                                )

    @staticmethod
    def _max_seq_len(image_size, kernel, stride, dilation):
        """Tokens per image for a square conv-style tokenizer, plus one for [CLS]."""
        side = (image_size - dilation * (kernel - 1) - 1) / stride + 1
        return int(side ** 2) + 1

    def forward(self, x):
        """
        :param x: image batch [B x C x H x W]
        :return: class logits [B x num_classes]
        """
        if self.mode == "patch":
            x = self.patch_embed(x)
        elif self.mode == "vqvae":
            # The tokenizer is used as a fixed feature extractor.
            self.vqvae.eval()
            with torch.no_grad():
                z = self.vqvae.encode(x)[0]
                # Quantize the ENCODED latents (not the raw image).
                _, _, quantized_z, _ = self.vqvae.vq(z)
            flattened_z = quantized_z.flatten(2).transpose(1, 2)  # [B, H'W', vq_dim]
            # Outside no_grad so the projection layer is actually trained.
            x = self.project_vs(flattened_z)
        else:  # "conv"
            x = self.conv2d(x)
            quantizer = self._conv_vqvae[0] if self._conv_vqvae[0] is not None else vqvae_model
            quantizer.eval()
            # NOTE(review): the quantizer runs under no_grad, so no gradient
            # reaches self.conv2d through this path -- confirm this is intended.
            with torch.no_grad():
                _, _, quantized_z, _ = quantizer.vq(x)
            flattened_z = quantized_z.flatten(2).transpose(1, 2)
            x = self.proj_feats(flattened_z)

        # Prepend the [CLS] token, then add the trailing slice of the
        # positional table that matches the sequence length.
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.pos_embed[:, -x.size(1):, :]
        for block in self.transformer_blocks:
            x = block(x)
        # Classify from the [CLS] position.
        logits = self.cls_head(x[:, 0])

        return logits

In [None]:
# Hyperparameters for the ViT classifier (trailing comments keep earlier
# values that were tried).

# Input / task geometry
image_size = 32
patch_size = 4
in_channels = 3
num_classes = 10

# Transformer size
embed_dim = 256 # 512
num_heads = 4
mlp_dim = 1024
num_layers = 6 # 4

# Regularization and batching
dropout = 0.1
batch_size = 128 # 256

# Tokenizer mode label; VitWithVQVAE accepts "patch", "vqvae" or "conv".
mode = "patch"

In [None]:
# Bare string expression: only echoes the VQ-VAE checkpoint path as the
# cell output; it assigns nothing.
'/kaggle/input/vq-vae-model-pth/vq_vae_model.pth'

In [None]:
# Rebuild the VQ-VAE and restore its weights from the saved checkpoint.
vq_dim = 64
path = '/kaggle/input/vq-vae-model-pth/'
vqvae_model = VQVAE(in_channels=3, num_embeddings=512, embedding_dim=vq_dim, use_ema=False).to(device)
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
vqvae_state_dict = torch.load(path+'vq_vae_model.pth', map_location=device)['state_dict']
vqvae_model.load_state_dict(vqvae_state_dict)
# The model is used for inference only from here on.
vqvae_model.eval()
vqvae_total_params = sum(param.numel() for param in vqvae_model.parameters())
vqvae_total_params

In [None]:
# Display the module tree of the loaded VQ-VAE.
vqvae_model

In [None]:
# Build the classifier with the tokenizer mode given as the literal "vqvae"
# (the `mode` variable defined earlier is not used here).
model = VitWithVQVAE(image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout, "vqvae", vq_dim).to(device)
# Smoke test: push one random image through the model and print the output shape.
input_tensor = torch.randn(1, in_channels, image_size, image_size).to(device)
output = model(input_tensor)
print(output.shape)

In [None]:
# Display the module tree of the classifier.
model

In [None]:
# Total number of parameter elements in the classifier. In "vqvae" mode the
# VQ-VAE is attached as a submodule, so its parameters are counted as well.
vit_total_params = sum(param.numel() for param in model.parameters())
vit_total_params

In [None]:
# NOTE(review): same computation as the previous cell (duplicate).
vit_total_params = sum(param.numel() for param in model.parameters())
vit_total_params

In [None]:
from torchvision import datasets, transforms

# CIFAR-10 pipelines for the classifier.
# Training: pad-and-crop + horizontal-flip augmentation, then tensor
# conversion and per-channel standardization.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.Resize(size=image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

# Evaluation: no augmentation, same standardization.
transform_test = transforms.Compose([
    transforms.Resize(size=image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

# Datasets (downloaded into ./data on first use)
trainset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

# Loaders use the current `batch_size`; only the training loader shuffles.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
    dataset=testset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Define the loss function and optimizer for the classifier
criterion = nn.CrossEntropyLoss()
# Optimization hyperparameters (trailing comments keep earlier values that were tried)
lr = 0.0006 # 0.003
weight_decay =  0 # 0.0001
num_epochs = 50 # 150
# Adam over every parameter returned by model.parameters()
optimizer = torch.optim.Adam(model.parameters(),
                                lr=lr,
                                weight_decay=weight_decay)
# One-cycle schedule sized for len(trainloader) * num_epochs steps, peaking
# at `lr`; it is stepped once per batch in the training loop.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(trainloader), epochs=num_epochs)

In [None]:
# Run configuration logged to Weights & Biases.
cfg = {
    'epoch': num_epochs,
    'lr': lr,
    'image_size': image_size,
    'patch_size': patch_size,
    'in_channels': in_channels,
    'embed_dim': embed_dim,
    'num_heads': num_heads,
    'mlp_dim': mlp_dim,
    'num_layers': num_layers,
    'num_classes': num_classes,
    'batch_size': batch_size,
    # Read the mode from the model that is actually being trained, so the
    # logged value cannot disagree with how the model was built.
    'mode': model.mode,
    'weight_decay': weight_decay,
    'optimizer': 'adam',
    'scheduler': 'onecyclelr',
}

# sweep_configuration = {
#     "method": "random",
#     "name": "sweep",
#     "metric": {"goal": "maximize", "name": "val_acc"},
#     "parameters": {
#         "batch_size": {"values": [128, 64]},
#         "epochs": {"values": [25, 50]},
#         "lr": {"max": 0.01, "min": 0.0001},
#         "mode": {"values": ["patch", "vqvae"]}
#     },
# }
# sweep_id = wandb.sweep(sweep=sweep_configuration, project='bagel')
run = wandb.init(project="bagel", group='vitvq', config=cfg)
# run = wandb.init()
# run = wandb.init()

In [None]:
# wandb.agent(sweep_id)
from tqdm import tqdm

In [None]:
# Train the model
best_val_acc = 0
train_accs = []   # mean training accuracy per epoch, in [0, 1]
test_accs = []    # validation accuracy per epoch, in percent
epochs_no_improve = 0
max_patience = 20
# NOTE(review): early stopping is only considered after this many epochs, so
# with num_epochs <= 101 it can never trigger -- confirm the intended value.
min_epochs_before_early_stop = 100
early_stop = False
pbar = tqdm(range(num_epochs))
for epoch in pbar:
    running_accuracy = 0.0
    running_loss = 0.0
    model.train()
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR advances once per batch
        # .item() keeps the running statistics as plain Python floats
        # instead of accumulating device tensors.
        acc = (outputs.argmax(dim=1) == labels).float().mean().item()
        running_accuracy += acc / len(trainloader)
        running_loss += loss.item() / len(trainloader)

    train_accs.append(running_accuracy)

    # Validate the model
    model.eval()
    correct = 0
    total = 0
    val_loss_sum = 0.0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Weight each batch by its size so the mean is per-sample.
            val_loss_sum += criterion(outputs, labels).item() * labels.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_acc = 100 * correct / total
    val_loss = val_loss_sum / total
    test_accs.append(val_acc)

    # Log epoch-level means (not the last batch) in a single step.
    wandb.log({'train_acc': running_accuracy, 'train_loss': running_loss,
               'val_acc': val_acc, 'val_loss': val_loss, 'epoch': epoch})
    pbar.set_postfix({"Epoch": epoch+1, "Train Accuracy": running_accuracy*100, "Training Loss": running_loss, "Validation Accuracy": val_acc})

    # Save the best model
    if val_acc > best_val_acc:
        epochs_no_improve = 0
        best_val_acc = val_acc
        # Store state_dicts rather than pickling the optimizer / scheduler
        # objects themselves.
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler' : scheduler.state_dict(),
            'train_acc': train_accs,
            'test_acc': val_acc
        },  'best_model.pth')
    else:
        epochs_no_improve += 1

    if epoch > min_epochs_before_early_stop and epochs_no_improve >= max_patience:
        print('Early stopping!')
        early_stop = True
        break

In [None]:
# Display the module tree of the classifier after training.
model