In [4]:
from transformers import AutoImageProcessor, SwinForMaskedImageModeling

# Load model and processor
processor_path = "processor"  # Replace with actual path
model_path = "model"          # Replace with actual path

image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192")
model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")
model.save_pretrained(model_path)
image_processor.save_pretrained(processor_path)

['processor/preprocessor_config.json']

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import os
import numpy as np
from sklearn.model_selection import train_test_split


class MaskedImageDataset(Dataset):
    def __init__(self, original_dir, masked_dirs, image_size):
        self.original_dir = original_dir
        self.masked_dirs = masked_dirs
        self.image_size = image_size

        self.original_images = sorted(os.listdir(original_dir))
        self.masked_images = []

        # Collect images from all masked directories
        for masked_dir in masked_dirs:
            masked_images = sorted(os.listdir(masked_dir))
            self.masked_images.extend([(masked_dir, img) for img in masked_images])

        # Filter out invalid images
        self.original_images = [img for img in self.original_images if self.is_valid_image(os.path.join(original_dir, img))]
        self.masked_images = [entry for entry in self.masked_images if self.is_valid_image(os.path.join(entry[0], entry[1]))]

    def __len__(self):
        return len(self.original_images)

    def __getitem__(self, idx):
        original_image_name = self.original_images[idx]
        masked_dir, masked_image_name = self.masked_images[idx]

        original_image = Image.open(os.path.join(self.original_dir, original_image_name)).convert("RGB")
        masked_image = Image.open(os.path.join(masked_dir, masked_image_name)).convert("RGB")

        original_image = original_image.resize((self.image_size, self.image_size))
        masked_image = masked_image.resize((self.image_size, self.image_size))

        original_image_tensor = torch.tensor(np.array(original_image).astype(np.float32) / 255.0).permute(2, 0, 1)
        masked_image_tensor = torch.tensor(np.array(masked_image).astype(np.float32) / 255.0).permute(2, 0, 1)

        return masked_image_tensor, original_image_tensor

    @staticmethod
    def is_valid_image(filepath):
        try:
            with Image.open(filepath) as img:
                img.verify()  # Check if it is a valid image
            return True
        except (IOError, SyntaxError):
            return False


def split_dataset_per_breed(original_dir, masked_root_dir, image_size, train_ratio=0.8):
    breed_name = os.path.basename(original_dir)
    masked_dirs = [os.path.join(masked_root_dir, f"{breed_name}_{color}") for color in ["black", "blue", "red", "white", "yellow"]]
    
    dataset = MaskedImageDataset(original_dir, masked_dirs, image_size)
    
    # Calculate sizes for splits
    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    test_size = total_size - train_size

    indices = list(range(total_size))
    train_indices, test_indices = train_test_split(indices, train_size=train_size, random_state=42)
    
    train_subset = Subset(dataset, train_indices)
    test_subset = Subset(dataset, test_indices)
    
    return train_subset, test_subset


# Directories containing images
original_dirs = [
    "data/dataset/Beagle",
    "data/dataset/Boxer",
    "data/dataset/Bulldog",
    "data/dataset/Dachshund",
    "data/dataset/German_Shepherd"
]
masked_root_dir = "new_dataset_5050"
image_size = 192

# Splitting datasets for each breed
train_datasets = []
test_datasets = []

for original_dir in original_dirs:
    train_subset, test_subset = split_dataset_per_breed(original_dir, masked_root_dir, image_size)
    train_datasets.append(train_subset)
    test_datasets.append(test_subset)

# Combine datasets from all breeds
train_dataset = torch.utils.data.ConcatDataset(train_datasets)
test_dataset = torch.utils.data.ConcatDataset(test_datasets)

# Initialize DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [32]:
len(test_datasets)

5

In [5]:
from transformers import AutoImageProcessor, SwinForMaskedImageModeling

# Load model and processor
processor_path = "processor"  # Replace with actual path
model_path = "model"          # Replace with actual path

image_processor = AutoImageProcessor.from_pretrained(processor_path)
model = SwinForMaskedImageModeling.from_pretrained(model_path, ignore_mismatched_sizes=True)


In [6]:
import torch.nn as nn
import torch.optim as optim

# nn.L1Loss()
# nn.MSELoss()
criterion = nn.L1Loss()  # Mean Squared Error loss for reconstruction
optimizer = optim.AdamW(model.parameters(), lr=1e-4)


In [7]:
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Initialize optimizer with weight decay
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

# Cosine learning rate scheduler
scheduler = CosineAnnealingLR(optimizer, T_max=10)  # 10 epochs

num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for masked_images, original_images in train_dataloader:
        masked_images = masked_images.to(device)
        original_images = original_images.to(device)

        # Generate mask positions
        batch_size, _, height, width = masked_images.shape
        patch_size = model.config.patch_size
        num_patches_per_row = height // patch_size
        num_patches = num_patches_per_row ** 2
        bool_masked_pos = torch.zeros((batch_size, num_patches), dtype=torch.bool).to(device)

        # Example: Random masking (adjust as needed)
        for i in range(batch_size):
            mask_indices = torch.randint(0, num_patches, (num_patches // 10,))  # Mask 10% patches
            bool_masked_pos[i, mask_indices] = True

        # Forward pass
        outputs = model(pixel_values=masked_images, bool_masked_pos=bool_masked_pos)
        loss = criterion(outputs.reconstruction, original_images)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Step the scheduler
    scheduler.step()
    
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

    # Save model after each epoch
    model.save_pretrained(f"fine_tuned_model_epoch_{epoch+1}")


Epoch [1/10], Loss: 0.1099, LR: 0.000098
Epoch [2/10], Loss: 0.0659, LR: 0.000090
Epoch [3/10], Loss: 0.0590, LR: 0.000079
Epoch [4/10], Loss: 0.0520, LR: 0.000065
Epoch [5/10], Loss: 0.0474, LR: 0.000050
Epoch [6/10], Loss: 0.0447, LR: 0.000035
Epoch [7/10], Loss: 0.0431, LR: 0.000021
Epoch [8/10], Loss: 0.0422, LR: 0.000010
Epoch [9/10], Loss: 0.0406, LR: 0.000002
Epoch [10/10], Loss: 0.0402, LR: 0.000000


In [8]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def evaluate_and_reconstruct(model, dataloader, save_reconstructions=False, output_dir="reconstructed_images"):
    model.eval()
    mse = 0.0
    total_samples = 0
    reconstructed_images_list = []

    with torch.no_grad():
        for batch_idx, (masked_images, original_images) in enumerate(dataloader):
            masked_images = masked_images.to(device)
            original_images = original_images.to(device)

            # Generate mask positions (same as training)
            batch_size, _, height, width = masked_images.shape
            patch_size = model.config.patch_size
            num_patches_per_row = height // patch_size
            num_patches = num_patches_per_row ** 2
            bool_masked_pos = torch.zeros((batch_size, num_patches), dtype=torch.bool).to(device)

            outputs = model(pixel_values=masked_images, bool_masked_pos=bool_masked_pos)
            reconstructed_images = outputs.reconstruction

            # Calculate loss
            mse += nn.functional.l1_loss(reconstructed_images, original_images, reduction="sum").item()
            total_samples += batch_size

            # Convert reconstructed images to numpy for saving/visualization
            reconstructed_images_np = (
                reconstructed_images.cpu().permute(0, 2, 3, 1).numpy()
            )  # Shape: (batch_size, height, width, channels)
            reconstructed_images_np = np.clip(reconstructed_images_np, 0, 1)  # Ensure values are in [0, 1]

            # Save or visualize
            if save_reconstructions:
                for i in range(batch_size):
                    img = (reconstructed_images_np[i] * 255).astype(np.uint8)  # Convert to [0, 255]
                    Image.fromarray(img).save(f"{output_dir}/reconstructed_{batch_idx * batch_size + i}.png")
                   


            reconstructed_images_list.append(reconstructed_images_np)

    mse = mse / total_samples  # Average MSE
    return mse, np.concatenate(reconstructed_images_list, axis=0)


In [9]:
# ["black", "blue", "red", "white", "yellow"]
for i in range(len(test_datasets)):
    test_dataloader = DataLoader(test_datasets[i], batch_size=8, shuffle=False)
    val_loss, reconstructed_images = evaluate_and_reconstruct(
        model,
        test_dataloader,
        save_reconstructions=True,
        output_dir="./reconstructed_images_1"
    )
    print(f"Validation L1 loss: {val_loss:.4f}")


Validation L1 loss: 3799.1011
Validation L1 loss: 3806.2742
Validation L1 loss: 3586.1477
Validation L1 loss: 4090.8860
Validation L1 loss: 4058.1747


In [12]:
sum(p.numel() for p in model.parameters())

89874104