In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import shutil
import importlib
import scripts.preprocessing as preprocessing
importlib.reload(preprocessing)
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchsummary
import torch.optim as optim
import scripts.models as models
import tqdm
import mlflow
import mlflow.pytorch
import os
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image



importlib.reload(models)

<module 'scripts.models' from '/home/ronin/Dev/notebooks/machinelearningformodeling/supervised/project/scripts/models.py'>

In [8]:
# Instantiate the colorization network from the project's `scripts.models` module.
model = models.ColorizationSqueezeNet()

In [9]:
# Print a layer-by-layer summary for a single-channel 224x224 input.
# The trailing ';' keeps the notebook from also echoing the call's return value.
# NOTE(review): the training transform further down resizes images to 256x256, not 224x224 - confirm which size is intended.
torchsummary.summary(model, (1, 224, 224));

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 512, 14, 14]         --
|    └─Conv2d: 2-1                       [-1, 64, 112, 112]        640
|    └─ReLU: 2-2                         [-1, 64, 112, 112]        --
|    └─MaxPool2d: 2-3                    [-1, 64, 56, 56]          --
|    └─Fire: 2-4                         [-1, 128, 56, 56]         --
|    |    └─Conv2d: 3-1                  [-1, 16, 56, 56]          1,040
|    |    └─ReLU: 3-2                    [-1, 16, 56, 56]          --
|    |    └─Conv2d: 3-3                  [-1, 64, 56, 56]          1,088
|    |    └─ReLU: 3-4                    [-1, 64, 56, 56]          --
|    |    └─Conv2d: 3-5                  [-1, 64, 56, 56]          9,280
|    |    └─ReLU: 3-6                    [-1, 64, 56, 56]          --
|    └─Fire: 2-5                         [-1, 128, 56, 56]         --
|    |    └─Conv2d: 3-7                  [-1, 16, 56, 56]          2,064
| 

In [10]:
class ImageDataset(Dataset):
    """Dataset of ``(grayscale, color)`` image pairs built from one folder.

    Every regular file directly inside ``image_folder`` is treated as an image.
    Each item is the RGB image (after the optional ``transform``) together with
    a single-channel grayscale version derived from it, so the grayscale image
    can serve as model input and the color image as the target.

    Args:
        image_folder: Directory containing the image files (not searched
            recursively; sub-directories are ignored).
        transform: Optional callable applied to the RGB PIL image before the
            grayscale version is computed.
    """

    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(image_folder)
            if os.path.isfile(os.path.join(image_folder, f))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(grayscale_image, color_image)`` for the file at ``idx``."""
        img_path = os.path.join(self.image_folder, self.image_files[idx])
        # The context manager closes the underlying file as soon as the RGB
        # copy has been made, so file handles are not left open.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        # The grayscale input is derived from the (possibly transformed) color
        # image so both share the same size.
        grayscale_image = transforms.functional.rgb_to_grayscale(image, num_output_channels=1)

        return grayscale_image, image

# Preprocessing applied to every image: resize to a fixed 256x256, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

# Grayscale/color pairs from the SSL image folder, served in shuffled mini-batches of 16.
dataset = ImageDataset('SSL/images', transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)


In [12]:
# Training setup: Adam over all model parameters (learning rate 1e-4),
# scored with mean-squared error between the prediction and the color target.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

In [14]:
from tqdm import tqdm
import mlflow
import mlflow.pytorch

def train(model, dataloader, criterion, optimizer, num_epochs=20):
    """Train ``model`` on ``(grayscale, color)`` batches and track the run in MLflow.

    The model is moved to CUDA when available (CPU otherwise). For every epoch
    the sample-weighted mean loss is printed and logged to MLflow as the
    ``loss`` metric; the trained model is logged as ``colorization_model`` once
    all epochs have finished.

    Args:
        model: Network mapping a grayscale batch to a prediction comparable
            with the color batch via ``criterion``.
        dataloader: Iterable yielding ``(grayscale, color)`` tensor batches.
        criterion: Loss callable ``criterion(outputs, color)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        list[float]: Mean training loss of each epoch, in order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    epoch_losses = []

    with mlflow.start_run():
        mlflow.log_param("num_epochs", num_epochs)
        mlflow.log_param("learning_rate", optimizer.param_groups[0]['lr'])
        mlflow.log_param("batch_size", dataloader.batch_size)

        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            # Count the samples actually processed: len(dataloader.dataset)
            # over-counts when drop_last=True or a sampler skips samples.
            samples_seen = 0

            # Use tqdm for the progress bar
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

            for grayscale, color in progress_bar:
                grayscale = grayscale.to(device)
                color = color.to(device)

                optimizer.zero_grad()
                outputs = model(grayscale)
                loss = criterion(outputs, color)
                loss.backward()
                optimizer.step()

                # Read the scalar once; it is reused for the running total and the progress bar.
                batch_loss = loss.item()
                batch_size = grayscale.size(0)
                running_loss += batch_loss * batch_size
                samples_seen += batch_size

                # Update progress bar with the current loss
                progress_bar.set_postfix(loss=batch_loss)

            # Guard against an empty epoch instead of raising ZeroDivisionError.
            epoch_loss = running_loss / samples_seen if samples_seen else float("nan")
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
            mlflow.log_metric("loss", epoch_loss, step=epoch)
            epoch_losses.append(epoch_loss)

        # Log the model at the end of the run
        mlflow.pytorch.log_model(model, "colorization_model")

    return epoch_losses


In [15]:
# Run the training loop for 5 epochs with the objects configured above.
train(
    model=model,
    dataloader=dataloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=5,
)

Epoch 1/5: 100%|██████████| 3576/3576 [50:02<00:00,  1.19it/s, loss=0.0124] 


Epoch 1/5, Loss: 0.0154


Epoch 2/5: 100%|██████████| 3576/3576 [47:20<00:00,  1.26it/s, loss=0.0135] 


Epoch 2/5, Loss: 0.0126


Epoch 3/5: 100%|██████████| 3576/3576 [47:12<00:00,  1.26it/s, loss=0.0126] 


Epoch 3/5, Loss: 0.0119


Epoch 4/5: 100%|██████████| 3576/3576 [47:14<00:00,  1.26it/s, loss=0.0109] 


Epoch 4/5, Loss: 0.0115


Epoch 5/5: 100%|██████████| 3576/3576 [47:11<00:00,  1.26it/s, loss=0.0134] 


Epoch 5/5, Loss: 0.0111


In [16]:

# Persist only the trained model's parameters (its state_dict) to 'modelSSL.pth'.
torch.save(model.state_dict(), 'modelSSL.pth')

In [34]:
# Load the trained colorization model from disk.
# map_location='cpu' lets a checkpoint saved during a GPU run be restored on a
# CPU-only machine; load_state_dict copies the values into the model's own parameters.
model = models.ColorizationSqueezeNet()
model.load_state_dict(torch.load('modelSSL.pth', map_location='cpu'))

<All keys matched successfully>

In [35]:
# Inspect the first layer of the feature extractor to check how many input channels it expects.
model.features[0]

Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

In [6]:
# Load the raw state dictionary of the SSL-trained colorization model from disk.
# map_location='cpu' keeps this working on machines without the GPU the
# checkpoint may have been saved from.
state_dict = torch.load('modelSSL.pth', map_location='cpu')


In [7]:
# Modify the first layer's weights to handle three channels
original_weights = state_dict['features.0.weight']  # Shape: [64, 1, 3, 3] in the saved checkpoint

if original_weights.size(1) == 1:
    # Copy the single-channel kernel to each of the three channels.
    # repeat() keeps the checkpoint's dtype and device.
    new_weights = original_weights.repeat(1, 3, 1, 1)
else:
    # Already expanded (e.g. this cell was re-run on the mutated state_dict):
    # leave the weights untouched instead of failing on a shape mismatch.
    new_weights = original_weights

# Replace the weights in the state dictionary
state_dict['features.0.weight'] = new_weights

In [14]:
# Create the SqueezeNet model (expected to take three input channels) once.
netFromSSL = models.SqueezeNet()

# Extract the `features` part of the state dictionary and strip only the
# leading "features." prefix (str.replace would also alter later occurrences).
FEATURES_PREFIX = 'features.'
features_state_dict = {
    k[len(FEATURES_PREFIX):]: v
    for k, v in state_dict.items()
    if k.startswith(FEATURES_PREFIX)
}

# Load the `features` part of the state dictionary into the new model.
# Only `netFromSSL.features` is loaded here; every other parameter of the
# model keeps whatever initialisation its constructor gave it.
netFromSSL.features.load_state_dict(features_state_dict)



<All keys matched successfully>

In [16]:
# Persist the parameters of the SSL-initialised SqueezeNet to 'netFromSSL.pth'.
torch.save(netFromSSL.state_dict(), 'netFromSSL.pth')