# Fine-tune an Existing Model
This notebook demonstrates how to train (or fine-tune) an existing model by loading pre-trained weights and continuing training.

In [None]:
# Import required libraries
import mvtech
import torchvision.utils as utils
import matplotlib.pyplot as plt
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import os
import numpy as np
import pytorch_ssim
import mdn1
from VT_AE import VT_AE as ae

In [None]:
# Configuration parameters
config = {
    'product': 'hazelnut',
    'epochs': 400,
    'learning_rate': 0.0001,
    'patch_size': 64,
    'batch_size': 2
}
print("Configuration:")
for key, value in config.items():
    print(f'{key}: {value}')

In [None]:
# Initialize tensorboard writer and training variables
writer = SummaryWriter()
prdt = config['product']
epoch = config['epochs']
minloss = 1e10
ep = 0
ssim_loss = pytorch_ssim.SSIM()
print(f'Training setup complete for product: {prdt}')
print(f'Training for {epoch} epochs')

In [None]:
# Dataset
data = mvtech.Mvtec(config['batch_size'], product=prdt)
print(f'Dataset loaded for {prdt}')
print(f'Batch size: {config['batch_size']}')

In [None]:
# Model declaration
model = ae(patch_size=config['patch_size'], train=True).cuda()
G_estimate = mdn1.MDN().cuda()
model.train()
G_estimate.train()
print('Models initialized and moved to GPU')
print(f'Patch size: {config['patch_size']}')

In [None]:
# Load existing model weights if available
model_path = f'./saved_model/VT_AE_Mvtech_{prdt}.pt'
g_estimate_path = f'./saved_model/G_estimate_Mvtech_{prdt}.pt'
if os.path.exists(model_path) and os.path.exists(g_estimate_path):
    model.load_state_dict(torch.load(model_path))
    G_estimate.load_state_dict(torch.load(g_estimate_path))
    print('Loaded existing model weights.')
else:
    print('No saved weights found, training from scratch.')

In [None]:
# Print model architectures
print('--- Architecture of VT_AE ---')
print(model)
print('--- Architecture of MDN ---')
print(G_estimate)

In [None]:
# Optimizer Declaration
Optimiser = Adam(list(model.parameters()) + list(G_estimate.parameters()), lr=config['learning_rate'], weight_decay=0.0001)
print(f'Optimizer initialized with learning rate: {config['learning_rate']}')

In [None]:
# Training loop
print('Network training started.....')
for i in range(epoch):
    t_loss = []
    for j, m in data.train_loader:
        if j.size(1) == 1:
            j = torch.stack([j, j, j]).squeeze(2).permute(1, 0, 2, 3)
        model.zero_grad()
        vector, reconstructions = model(j.cuda())
        pi, mu, sigma = G_estimate(vector)
        loss1 = F.mse_loss(reconstructions, j.cuda(), reduction='mean')
        loss2 = -ssim_loss(j.cuda(), reconstructions)
        loss3 = mdn1.mdn_loss_function(vector, mu, sigma, pi)
        print(f'loss3: {loss3.item()}')
        loss = 5 * loss1 + 0.5 * loss2 + loss3
        t_loss.append(loss.item())
        writer.add_scalar('recon-loss', loss1.item(), i)
        writer.add_scalar('ssim loss', loss2.item(), i)
        writer.add_scalar('Gaussian loss', loss3.item(), i)
        writer.add_histogram('Vectors', vector)
        # writer.add_histogram('Pi', pi)
        # writer.add_histogram('Variance', sigma)
        # writer.add_histogram('Mean', mu)
        loss.backward()
        Optimiser.step()
    writer.add_image('Reconstructed Image', utils.make_grid(reconstructions), i, dataformats='CHW')
    writer.add_scalar('Mean Epoch loss', np.mean(t_loss), i)
    print(f'Mean Epoch {i} loss: {np.mean(t_loss)}')
    print(f'Min loss epoch: {ep} with min loss: {minloss}')
    if np.mean(t_loss) <= minloss:
        minloss = np.mean(t_loss)
        ep = i
        os.makedirs('./saved_model', exist_ok=True)
        torch.save(model.state_dict(), model_path)
        torch.save(G_estimate.state_dict(), g_estimate_path)
        print(f'New best model saved at epoch {i} with loss {minloss}')
writer.close()
print('Training completed!')

In [20]:
import torch
import mvtech
import torchvision.utils as utils
import matplotlib.pyplot as plt
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import os
import numpy as np
import pytorch_ssim
import mdn1
from VT_AE import VT_AE as ae
import argparse

In [37]:
writer = SummaryWriter()

# Configuration parameters (modify these as needed)
config = {
    'product': 'bottle',  # product from the dataset MvTec or BTAD
    'epochs': 400,          # Number of epochs to train
    'learning_rate': 0.0001, # learning rate
    'patch_size': 64,       # Patch size of the images
    'batch_size': 2         # batch size
}

print("Configuration:")
for key, value in config.items():
    print(f"{key}: {value}")

Configuration:
product: bottle
epochs: 400
learning_rate: 0.0001
patch_size: 64
batch_size: 2


In [26]:
model = ae(train=False).cuda()
G_estimate = mdn1.MDN().cuda()

# Loading weights
model.load_state_dict(torch.load(f'./saved_model/VT_AE_Mvtech_hazelnut'+'.pt'))
G_estimate.load_state_dict(torch.load(f'./saved_model/G_estimate_Mvtech_hazelnut'+'.pt'))

<All keys matched successfully>

lemon trees

In [27]:
print("\n--- Architecture of whatever model this is ---\n")
print(model)

print("\n\n\n\n\n\n--- Architecture of mixxed density network ---\n")
print(G_estimate)


--- Architecture of whatever model this is ---

VT_AE(
  (vt): ViT(
    (patch_to_embedding): Linear(in_features=12288, out_features=512, bias=True)
    (transformer): Transformer(
      (layers): ModuleList(
        (0-5): 6 x ModuleList(
          (0): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (fn): Attention(
                (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=512, out_features=512, bias=True)
                )
              )
            )
          )
          (1): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (fn): FeedForward(
                (net): Sequential(
                  (0): Linear(in_features=512, out_features=1024, bias=True)
                  (1): GELU(approximate='none')
                  (2): L

In [29]:
# Dataset
data = mvtech.Mvtec(config["batch_size"], product=prdt)

print(f"Dataset loaded for {config['product']}")
print(f"Batch size: {config['batch_size']}")

total train images of good bottle are: 209
total test images of broken_large bottle are: 20
total test images of broken_small bottle are: 22
total test images of contamination bottle are: 21
the good images for test images of good bottle is not included in the test anomolous data
total ground_truth images of broken_large bottle are: 20
total ground_truth images of broken_small bottle are: 22
total ground_truth images of contamination bottle are: 21
total test images of good bottle are: 20
 --Size of bottle train loader: torch.Size([209, 3, 512, 512])--
 --Size of bottle test anomaly loader: torch.Size([63, 3, 512, 512])--
 --Size of bottle test normal loader: torch.Size([20, 3, 512, 512])--
 --Total Image in bottle Validation loader: 20--
Dataset loaded for bottle
Batch size: 2


In [30]:
# Optimizer Declaration
Optimiser = Adam(
    list(model.parameters()) + list(G_estimate.parameters()), 
    lr=config["learning_rate"], 
    weight_decay=0.0001
)

print(f"Optimizer initialized with learning rate: {config['learning_rate']}")

Optimizer initialized with learning rate: 0.0001


In [34]:
def freeze_vit_except_lora(model):
    for name, param in model.named_parameters():
        param.requires_grad = False
    for module in model.modules():
        if hasattr(module, 'lora_A') and module.lora_A is not None:
            module.lora_A.requires_grad = True
        if hasattr(module, 'lora_B') and module.lora_B is not None:
            module.lora_B.requires_grad = True

freeze_vit_except_lora(model)

In [38]:
# Uncomment the line below if you want to track errors
# torch.autograd.set_detect_anomaly(True)

print('\nNetwork training started.....')

for i in range(config['epochs']):
    t_loss = []
    
    for j, m in data.train_loader:
        # Handle grayscale images by converting to 3-channel
        if j.size(1) == 1:
            j = torch.stack([j, j, j]).squeeze(2).permute(1, 0, 2, 3)
        
        model.zero_grad()
        
        # Forward pass
        vector, reconstructions = model(j.cuda())
        pi, mu, sigma = G_estimate(vector)
        
        # Loss calculations
        loss1 = F.mse_loss(reconstructions, j.cuda(), reduction='mean')  # Reconstruction Loss
        loss2 = -ssim_loss(j.cuda(), reconstructions)  # SSIM loss for structural similarity
        loss3 = mdn1.mdn_loss_function(vector, mu, sigma, pi)  # MDN loss for gaussian approximation
        
        print(f'loss3: {loss3.item()}')
        loss = 5 * loss1 + 0.5 * loss2 + loss3  # Total loss
        
        t_loss.append(loss.item())  # storing all batch losses to calculate mean epoch loss
        
        # Tensorboard logging
        writer.add_scalar('recon-loss', loss1.item(), i)
        writer.add_scalar('ssim loss', loss2.item(), i)
        writer.add_scalar('Gaussian loss', loss3.item(), i)
        writer.add_histogram('Vectors', vector)
        
        ## Uncomment below to store the distributions of pi, var and mean ##        
        # writer.add_histogram('Pi', pi)
        # writer.add_histogram('Variance', sigma)
        # writer.add_histogram('Mean', mu)

        # Backward pass and optimization
        loss.backward()
        Optimiser.step()
    
    # Log epoch-level information
    writer.add_image('Reconstructed Image', utils.make_grid(reconstructions), i, dataformats='CHW')
    writer.add_scalar('Mean Epoch loss', np.mean(t_loss), i)
    
    print(f'Mean Epoch {i} loss: {np.mean(t_loss)}')
    print(f'Min loss epoch: {ep} with min loss: {minloss}')
    
    # Save the best model
    if np.mean(t_loss) <= minloss:
        minloss = np.mean(t_loss)
        ep = i
        os.makedirs('./saved_model', exist_ok=True)
        torch.save(model.state_dict(), f'./saved_model/VT_AE_Mvtech_{prdt}_peft.pt')
        torch.save(G_estimate.state_dict(), f'./saved_model/G_estimate_Mvtech_{prdt}_peft.pt')
        print(f"New best model saved at epoch {i} with loss {minloss}")

writer.close()
print("Training completed!")


Network training started.....
loss3: 1377765.875
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss3: nan
loss

NameError: name 'ep' is not defined

# Fine-tune Existing Model
This notebook demonstrates how to train (or continue training) an existing model by loading its weights and running the training loop.

In [None]:
# Import required libraries
import mvtech
import torchvision.utils as utils
import matplotlib.pyplot as plt
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import os
import numpy as np
import pytorch_ssim
import mdn1
from VT_AE import VT_AE as ae
import torch