Custom Trained VGG16 from Scratch

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm  # For progress bar

# Data Handling Class
class DataHandler:
    def __init__(self, train_dir, valid_dir, batch_size):
        self.train_dir = train_dir
        self.valid_dir = valid_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def load_data(self):
        train_data = ImageFolder(root=self.train_dir, transform=self.transform)
        valid_data = ImageFolder(root=self.valid_dir, transform=self.transform)

        train_loader = DataLoader(dataset=train_data, batch_size=self.batch_size, shuffle=True)
        valid_loader = DataLoader(dataset=valid_data, batch_size=self.batch_size, shuffle=False)

        return train_loader, valid_loader, train_data.classes


class VGG16(nn.Module):
    def __init__(self, num_classes, dropout_rate=0.5):
        super(VGG16, self).__init__()
        self.features = self._make_layers()
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(4096, num_classes),
        )

    def _make_layers(self):
        layers = []
        in_channels = 3  # For RGB images
        # Reducing the number of filters to prevent memory issues
        config = [32, 32, 'M', 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M']
        
        for x in config:
            if x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(nn.Conv2d(in_channels, x, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                in_channels = x

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flattening
        x = self.classifier(x)
        return x

class Trainer:
    def __init__(self, model, train_loader, valid_loader, device, num_epochs=10, learning_rate=0.001, dropout_rate=0.5):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.dropout_rate = dropout_rate
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.writer = SummaryWriter()

    def train(self):
        for epoch in range(self.num_epochs):
            self.model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            # Progress bar for training
            train_loader_tqdm = tqdm(self.train_loader, desc=f"Training Epoch {epoch+1}/{self.num_epochs}", leave=False)

            for inputs, labels in train_loader_tqdm:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # Forward pass
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                train_loader_tqdm.set_postfix(loss=loss.item())

            epoch_loss = running_loss / len(self.train_loader)
            epoch_acc = correct / total
            self.writer.add_scalar('Loss/train', epoch_loss, epoch)
            self.writer.add_scalar('Accuracy/train', epoch_acc, epoch)

            print(f'Epoch [{epoch+1}/{self.num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

            self.validate(epoch)

    def validate(self, epoch):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        valid_loader_tqdm = tqdm(self.valid_loader, desc="Validating", leave=False)

        with torch.no_grad():
            for inputs, labels in valid_loader_tqdm:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Update progress bar
                valid_loader_tqdm.set_postfix(loss=loss.item())

       
        epoch_loss = running_loss / len(self.valid_loader)
        epoch_acc = correct / total
        self.writer.add_scalar('Loss/valid', epoch_loss, epoch)
        self.writer.add_scalar('Accuracy/valid', epoch_acc, epoch)

        print(f'Validation Loss: {epoch_loss:.4f}, Validation Accuracy: {epoch_acc:.4f}')

    def save_checkpoint(self, path):
        torch.save(self.model.state_dict(), path)
        print(f"Model saved to {path}")

    def load_checkpoint(self, path):
        self.model.load_state_dict(torch.load(path))
        print(f"Model loaded from {path}")

    def close_writer(self):
        self.writer.close()

# Main function
def main():
    train_dir = 'data/data/train'
    valid_dir = 'data/data/valid'
    batch_size = 16  # Reduced batch size to 16 only so that we can prevent memory overflow
    num_epochs = 10
    learning_rate = 0.001
    dropout_rate = 0.5  

 
    data_handler = DataHandler(train_dir, valid_dir, batch_size)
    train_loader, valid_loader, classes = data_handler.load_data()

   
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 
    model = VGG16(num_classes=len(classes), dropout_rate=dropout_rate)

    # Initialize trainer
    trainer = Trainer(model, train_loader, valid_loader, device, num_epochs, learning_rate, dropout_rate)
    
    
    trainer.train()

    
    trainer.save_checkpoint("vgg16_custom.pth")

 
    trainer.close_writer()

if __name__ == "__main__":
    main()


                                                                                                                       

Epoch [1/10], Loss: 2.0072, Accuracy: 0.2824


                                                                                                                       

Validation Loss: 2.0376, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [2/10], Loss: 1.9854, Accuracy: 0.2882


                                                                                                                       

Validation Loss: 1.8170, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [3/10], Loss: 1.8796, Accuracy: 0.2294


                                                                                                                       

Validation Loss: 1.8083, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [4/10], Loss: 1.8823, Accuracy: 0.3059


                                                                                                                       

Validation Loss: 1.8268, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [5/10], Loss: 1.8857, Accuracy: 0.3059


                                                                                                                       

Validation Loss: 1.8093, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [6/10], Loss: 1.8837, Accuracy: 0.3059


                                                                                                                       

Validation Loss: 1.8166, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [7/10], Loss: 1.8719, Accuracy: 0.2824


                                                                                                                       

Validation Loss: 1.8186, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [8/10], Loss: 1.8616, Accuracy: 0.3000


                                                                                                                       

Validation Loss: 1.8144, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [9/10], Loss: 1.8643, Accuracy: 0.3059


                                                                                                                       

Validation Loss: 1.8087, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [10/10], Loss: 1.8601, Accuracy: 0.2941


                                                                                                                       

Validation Loss: 1.8309, Validation Accuracy: 0.2667
Model saved to vgg16_custom.pth


PreTrained VGG16 Network of Pytorch:

In [7]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from tqdm import tqdm  # For progress bar

# Data Handling Class
class DataHandler:
    def __init__(self, train_dir, valid_dir, batch_size):
        self.train_dir = train_dir
        self.valid_dir = valid_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def load_data(self):
        train_data = ImageFolder(root=self.train_dir, transform=self.transform)
        valid_data = ImageFolder(root=self.valid_dir, transform=self.transform)

        train_loader = DataLoader(dataset=train_data, batch_size=self.batch_size, shuffle=True)
        valid_loader = DataLoader(dataset=valid_data, batch_size=self.batch_size, shuffle=False)

        return train_loader, valid_loader, train_data.classes


class Trainer:
    def __init__(self, model, train_loader, valid_loader, device, num_epochs=10, learning_rate=0.001):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.writer = SummaryWriter()

    def train(self):
        for epoch in range(self.num_epochs):
            self.model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            # Progress bar for training
            train_loader_tqdm = tqdm(self.train_loader, desc=f"Training Epoch {epoch+1}/{self.num_epochs}", leave=False)

            for inputs, labels in train_loader_tqdm:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # Forward pass
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                # Backward pass and optimization
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Track accuracy and loss
                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Update progress bar
                train_loader_tqdm.set_postfix(loss=loss.item())
                
            epoch_loss = running_loss / len(self.train_loader)
            epoch_acc = correct / total
            self.writer.add_scalar('Loss/train', epoch_loss, epoch)
            self.writer.add_scalar('Accuracy/train', epoch_acc, epoch)

            print(f'Epoch [{epoch+1}/{self.num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

            # Validation at the end of each epoch
            self.validate(epoch)

    def validate(self, epoch):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        # Progress bar for validation
        valid_loader_tqdm = tqdm(self.valid_loader, desc="Validating", leave=False)

        with torch.no_grad():
            for inputs, labels in valid_loader_tqdm:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Update progress bar
                valid_loader_tqdm.set_postfix(loss=loss.item())

      
        epoch_loss = running_loss / len(self.valid_loader)
        epoch_acc = correct / total
        self.writer.add_scalar('Loss/valid', epoch_loss, epoch)
        self.writer.add_scalar('Accuracy/valid', epoch_acc, epoch)

        print(f'Validation Loss: {epoch_loss:.4f}, Validation Accuracy: {epoch_acc:.4f}')

    def save_checkpoint(self, path):
        torch.save(self.model.state_dict(), path)
        print(f"Model saved to {path}")

    def load_checkpoint(self, path):
        self.model.load_state_dict(torch.load(path))
        print(f"Model loaded from {path}")

    def close_writer(self):
        self.writer.close()

# Main function
def main():
    train_dir = 'data/data/train'
    valid_dir = 'data/data/valid'
    batch_size = 16  # Adjusted batch size
    num_epochs = 10
    learning_rate = 0.001

   
    data_handler = DataHandler(train_dir, valid_dir, batch_size)
    train_loader, valid_loader, classes = data_handler.load_data()

    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    from torchvision.models import VGG16_Weights

    model = models.vgg16(weights=VGG16_Weights.DEFAULT)


    num_features = model.classifier[6].in_features  # Get the input features of the last layer
    model.classifier[6] = nn.Linear(num_features, len(classes))  # Adjust to your number of classes


   
    trainer = Trainer(model, train_loader, valid_loader, device, num_epochs, learning_rate)
    
    trainer.train()
    trainer.save_checkpoint("vgg16_pretrained_custom.pth")
    
    trainer.close_writer()

if __name__ == "__main__":
    main()


                                                                                                                       

Epoch [1/10], Loss: 3.1434, Accuracy: 0.2294


                                                                                                                       

Validation Loss: 1.8663, Validation Accuracy: 0.3333


                                                                                                                       

Epoch [2/10], Loss: 1.9699, Accuracy: 0.3882


                                                                                                                       

Validation Loss: 1.2796, Validation Accuracy: 0.6000


                                                                                                                       

Epoch [3/10], Loss: 1.4288, Accuracy: 0.5118


                                                                                                                       

Validation Loss: 1.3594, Validation Accuracy: 0.7333


                                                                                                                       

Epoch [4/10], Loss: 1.2995, Accuracy: 0.6000


                                                                                                                       

Validation Loss: 0.9446, Validation Accuracy: 0.7333


                                                                                                                       

Epoch [5/10], Loss: 1.1180, Accuracy: 0.5706


                                                                                                                       

Validation Loss: 0.8962, Validation Accuracy: 0.6667


                                                                                                                       

Epoch [6/10], Loss: 1.0204, Accuracy: 0.6353


                                                                                                                       

Validation Loss: 0.9693, Validation Accuracy: 0.7333


                                                                                                                       

Epoch [7/10], Loss: 0.9098, Accuracy: 0.6882


                                                                                                                       

Validation Loss: 0.6584, Validation Accuracy: 0.8000


                                                                                                                       

Epoch [8/10], Loss: 0.7836, Accuracy: 0.7176


                                                                                                                       

Validation Loss: 0.6674, Validation Accuracy: 0.8000


                                                                                                                       

Epoch [9/10], Loss: 0.6854, Accuracy: 0.7235


                                                                                                                       

Validation Loss: 0.5680, Validation Accuracy: 0.8000


                                                                                                                       

Epoch [10/10], Loss: 0.7208, Accuracy: 0.7294


                                                                                                                       

Validation Loss: 0.7305, Validation Accuracy: 0.7333
Model saved to vgg16_pretrained_custom.pth
