In [1]:
pip install pytorch-msssim


Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-1.0.0


In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from pytorch_msssim import ssim
from torchvision import datasets, transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import tarfile
import os
import shutil
import requests
import random
import json
from tqdm import tqdm




In [None]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv_layers = nn.ModuleList([
            nn.Conv2d(in_channels=3,out_channels = 32, kernel_size=3, padding=1),
            nn.Conv2d(in_channels= 32,out_channels = 64, kernel_size=3, padding=1),
            nn.Conv2d(in_channels= 64, out_channels=128, kernel_size=3, padding=1),
            nn.Conv2d(in_channels= 128, out_channels=256, kernel_size=3, padding=1)
        ])

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = nn.Conv2d(in_channels= 256,out_channels= 64, kernel_size=3, padding=1)

    def forward(self, x):
        for conv in self.conv_layers:
            x = F.relu(conv(x))
            x = self.pool(x)

        x = F.relu(self.bottleneck(x))
        return x





In [4]:

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.deconv_layers = nn.ModuleList([
            nn.ConvTranspose2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1),
        ])

        self.final_layer = nn.ConvTranspose2d(in_channels=32, out_channels = 3, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.deconv_layers[0](x))
        for deconv in self.deconv_layers[1:]:
            x = F.relu(deconv(x))
        x = torch.sigmoid(self.final_layer(x))
        return x

In [5]:
encoder = Encoder()
decoder = Decoder()


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
decoder.to(device)

Decoder(
  (deconv_layers): ModuleList(
    (0): ConvTranspose2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (final_layer): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)

In [7]:


url = "https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar?download=1"
dataset_path = "256_ObjectCategories.tar"

print(f"Downloading dataset from {url}...")
response = requests.get(url, stream=True)
response.raise_for_status()  # Raise an exception for bad status codes

with open(dataset_path, 'wb') as f:
    for chunk in response.iter_content(chunk_size=8192):
        f.write(chunk)

print("Download complete.")

Downloading dataset from https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar?download=1...
Download complete.


In [8]:

dataset_path = "256_ObjectCategories.tar"
extract_dir = "256_ObjectCategories"

if os.path.exists(dataset_path):
    print(f"Extracting dataset to {extract_dir}...")
    with tarfile.open(dataset_path, "r") as tar:
        tar.extractall(path=extract_dir)
    print("Extraction complete.")
else:
    print(f"Error: The file '{dataset_path}' was not found.")

Extracting dataset to 256_ObjectCategories...


  tar.extractall(path=extract_dir)


Extraction complete.


In [9]:


source_dir = "256_ObjectCategories"
target_dir = "dataset"

# Create the target directory if it doesn't exist
os.makedirs(target_dir, exist_ok=True)

print(f"Consolidating images from '{source_dir}' into '{target_dir}'...")

# Iterate through all directories and files in the source directory
for root, dirs, files in os.walk(source_dir, topdown=False): # topdown=False to process subdirectories before their parent
    for file in files:
        # Construct the full path to the file
        file_path = os.path.join(root, file)
        # Construct the new path for the file in the target directory
        new_file_path = os.path.join(target_dir, file)

        # Ensure unique filenames if necessary (optional, depending on dataset)
        # If filenames might not be unique, you might want to rename them
        # For this dataset, it seems filenames within a category are unique,
        # but filenames across categories might not be. Let's rename them
        # to include the category name to be safe.

        # Get the category name from the root path
        category_name = os.path.basename(root)
        if category_name and category_name != source_dir: # Avoid using the base source dir as category
             new_file_name = f"{category_name}_{file}"
             new_file_path = os.path.join(target_dir, new_file_name)


        try:
            # Move the file
            shutil.move(file_path, new_file_path)
            # print(f"Moved: {file_path} -> {new_file_path}") # Optional: uncomment to see file movements
        except Exception as e:
            print(f"Error moving file {file_path}: {e}")

    # After moving files, remove the subdirectory if it's empty
    if root != source_dir and not os.listdir(root):
        try:
            os.rmdir(root)
            # print(f"Removed empty directory: {root}") # Optional: uncomment to see removed directories
        except OSError as e:
            print(f"Error removing directory {root}: {e}")


# Optional: remove the original source directory if all files/subdirs were moved
# Be cautious with this step, only uncomment if you are sure all needed files are moved.
# if not os.listdir(source_dir):
#     try:
#         os.rmdir(source_dir)
#         print(f"Removed original source directory: {source_dir}")
#     except OSError as e:
#          print(f"Error removing source directory {source_dir}: {e}")


print("Image consolidation complete.")

Consolidating images from '256_ObjectCategories' into 'dataset'...
Image consolidation complete.


In [10]:



dataset_dir = '/content/dataset'
train_dir = '/content/train'
val_dir = '/content/val'

os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

# Get all image files
all_files = [f for f in os.listdir(dataset_dir) if os.path.isfile(os.path.join(dataset_dir, f))]

# Shuffle randomly
random.shuffle(all_files)

# Split 80% train, 20% val
split_idx = int(len(all_files) * 0.8)
train_files = all_files[:split_idx]
val_files = all_files[split_idx:]

# Move files
for f in train_files:
    shutil.move(os.path.join(dataset_dir, f), os.path.join(train_dir, f))

for f in val_files:
    shutil.move(os.path.join(dataset_dir, f), os.path.join(val_dir, f))

print(f"Total images: {len(all_files)}")
print(f"Training images: {len(train_files)}")
print(f"Validation images: {len(val_files)}")


Total images: 30608
Training images: 24486
Validation images: 6122


In [14]:
from PIL import Image
import os
from torchvision import transforms

transform1 = transforms.Resize((256, 256))  # returns PIL Image

def process_and_save_images(src_dir):
    i=0
    for fname in os.listdir(src_dir):
      try:
          i+=1
          fpath = os.path.join(src_dir, fname)
          if os.path.isfile(fpath):
              img = Image.open(fpath).convert('RGB')  # ensure 3 channels
              img_resized = transform1(img)  # Resize
              img_resized.save(fpath)        # Save resized image
              if i % 100 == 0:
                  print(f"Processed {i} images")
      except Exception as e:
          os.remove(fpath)
          print(f"Error processing {fpath}: {e}")




# Process train and val
process_and_save_images(train_dir)
process_and_save_images(val_dir)


Processed 100 images
Processed 200 images
Processed 300 images
Processed 400 images
Processed 500 images
Processed 600 images
Processed 700 images
Processed 800 images
Processed 900 images
Processed 1000 images
Processed 1100 images
Processed 1200 images
Processed 1300 images
Processed 1400 images
Processed 1500 images
Processed 1600 images
Processed 1700 images
Processed 1800 images
Processed 1900 images
Processed 2000 images
Processed 2100 images
Processed 2200 images
Processed 2300 images
Processed 2400 images
Processed 2500 images
Processed 2600 images
Processed 2700 images
Processed 2800 images
Processed 2900 images
Processed 3000 images
Processed 3100 images
Processed 3200 images
Processed 3300 images
Processed 3400 images
Processed 3500 images
Processed 3600 images
Processed 3700 images
Processed 3800 images
Processed 3900 images
Processed 4000 images
Processed 4100 images
Processed 4200 images
Processed 4300 images
Processed 4400 images
Processed 4500 images
Processed 4600 imag

In [15]:


transform = transforms.Compose([
    transforms.ToTensor()
])

class ImageFolderNoClass(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.image_files = [os.path.join(folder_path, f)
                            for f in os.listdir(folder_path)
                            if os.path.isfile(os.path.join(folder_path, f))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

# Paths
train_dir = '/content/train'
val_dir = '/content/val'

# Create datasets
train_dataset = ImageFolderNoClass(train_dir, transform=transform)
val_dataset = ImageFolderNoClass(val_dir, transform=transform)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

for batch in train_loader:
    print(batch.shape)
    break


torch.Size([32, 3, 256, 256])


In [None]:

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Directories
os.makedirs('./checkpoints', exist_ok=True)
loss_file = './checkpoints/losses.json'
model_path = './checkpoints/model.pth'

# Initialize models
encoder = Encoder().to(device)
decoder = Decoder().to(device)

# Optimizer & loss
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)
l1_criterion = nn.L1Loss()

# Load checkpoint if exists
start_epoch = 0
losses_dict = {
    'train_l1': [], 'train_ssim': [], 'train_total': [],
    'val_l1': [], 'val_ssim': [], 'val_total': []
}

if os.path.exists(model_path):
    print(f"Loading checkpoint from {model_path}...")
    checkpoint = torch.load(model_path, map_location=device)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    losses_dict = checkpoint.get('losses', losses_dict)
    print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
else:
    print("No checkpoint found. Starting fresh training.")

# Training loop
num_epochs = 20
for epoch in range(start_epoch, start_epoch + num_epochs):
    encoder.train()
    decoder.train()
    running_l1, running_ssim, running_total = 0.0, 0.0, 0.0

    for images in tqdm(train_loader, desc=f"Epoch {epoch+1}/{start_epoch + num_epochs} [Training]"):
        images = images.to(device)

        optimizer.zero_grad()
        latent = encoder(images)
        outputs = decoder(latent)

        l1_loss_val = l1_criterion(outputs, images)
        ssim_val = ssim(outputs, images, data_range=1.0, size_average = True)
        total_loss = 0.8 * l1_loss_val + 0.2 * (1 - ssim_val)

        total_loss.backward()
        optimizer.step()

        running_l1 += l1_loss_val.item()
        running_ssim += ssim_val.item()
        running_total += total_loss.item()

    # Training loss averages
    train_l1_avg = running_l1 / len(train_loader)
    train_ssim_avg = running_ssim / len(train_loader)
    train_total_avg = running_total / len(train_loader)

    losses_dict['train_l1'].append(train_l1_avg)
    losses_dict['train_ssim'].append(train_ssim_avg)
    losses_dict['train_total'].append(train_total_avg)

    # Validation
    encoder.eval()
    decoder.eval()
    val_l1_total, val_ssim_total, val_total_total = 0.0, 0.0, 0.0

    with torch.no_grad():
        for images in tqdm(val_loader, desc=f"Epoch {epoch+1}/{start_epoch + num_epochs} [Validation]"):
            images = images.to(device)
            latent = encoder(images)
            outputs = decoder(latent)

            l1_loss_val = l1_criterion(outputs, images)
            ssim_val = ssim(outputs, images, data_range=1.0, size_average = True)
            total_loss = 0.8 * l1_loss_val + 0.2 * (1 - ssim_val)

            val_l1_total += l1_loss_val.item()
            val_ssim_total += ssim_val.item()
            val_total_total += total_loss.item()

    val_l1_avg = val_l1_total / len(val_loader)
    val_ssim_avg = val_ssim_total / len(val_loader)
    val_total_avg = val_total_total / len(val_loader)

    losses_dict['val_l1'].append(val_l1_avg)
    losses_dict['val_ssim'].append(val_ssim_avg)
    losses_dict['val_total'].append(val_total_avg)

    print(f"Epoch [{epoch+1}/{start_epoch + num_epochs}] "
          f"Train Loss: {train_total_avg:.4f} (L1: {train_l1_avg:.4f}, SSIM: {train_ssim_avg:.4f}) "
          f"Val Loss: {val_total_avg:.4f} (L1: {val_l1_avg:.4f}, SSIM: {val_ssim_avg:.4f})")

    # checkpointing
    torch.save({
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch + 1,
        'losses': losses_dict
    }, model_path)



    # Save JSON
    with open(loss_file, 'w') as f:
        json.dump(losses_dict, f, indent=4)


Using device: cuda
No checkpoint found. Starting fresh training.


Epoch 1/20 [Training]: 100%|██████████| 766/766 [02:30<00:00,  5.11it/s]
Epoch 1/20 [Validation]: 100%|██████████| 192/192 [00:20<00:00,  9.33it/s]


Epoch [1/20] Train Loss: 0.1785 (L1: 0.1099, SSIM: 0.5469) Val Loss: 0.1386 (L1: 0.0755, SSIM: 0.6089)


Epoch 2/20 [Training]:  51%|█████     | 392/766 [01:15<01:10,  5.32it/s]