In [1]:
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from utils_cells import get_images_list, transform_image, transform_target, resize_with_padding
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import numpy as np
import torchvision.transforms.functional as F
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
import cv2
from sklearn.model_selection import train_test_split
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torchmetrics import Precision, Recall
import numpy as np
import datetime
import random
import time
import torchvision.models as models

import random

class ImageDataset(Dataset):
    def __init__(self, data_path, transform=None, target_transform=None, reduce=False):
        self.transform = transform
        self.target_transform = target_transform
        self.dataset = shuffle(self.load_dataset(data_path))

    def load_dataset(self, path):
        path = []
        classes = []
        for image_class in os.listdir('cells_final'):
            for img in os.listdir(f'cells_final/{image_class}'):
                path.append(f'cells_final/{image_class}/{img}')
                classes.append(image_class)

        dataset_final = pd.DataFrame()
        dataset_final['path'] = path
        dataset_final['class'] = classes
        return dataset_final                
                          
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image = cv2.imread(f'{self.dataset["path"].loc[idx]}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (32, 32), interpolation=cv2.INTER_CUBIC)
        
        #image = resize_with_padding(image, (32, 32))
        image = image.astype(np.float32)
        image = image/255.0
        image = self.transform(image = image)['image'] if self.transform is not None else image

        target = self.dataset['class'].loc[idx]

        if target.strip() == 'normal':
            target_ = [1, 0, 0, 0]
        elif target.strip() == 'inflammatory':
            target_ = [0, 1, 0, 0]
        elif target.strip() == 'tumor':
            target_ = [0, 0, 1, 0]
        elif target.strip() == 'other':
            target_ = [0, 0, 0, 1]
        else:
            print(target)
        
        image = F.to_tensor(image)
        

        return image.float(), torch.Tensor(np.array(target_, dtype=np.float32))

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import torch
import torch.nn as nn
from einops import rearrange

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=8, emb_size=256, img_size=32):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.pos_embed = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))
        
    def forward(self, x):
        B = x.shape[0]
        x = self.proj(x)  # (B, emb_size, h, w)
        x = rearrange(x, 'b e (h) (w) -> b (h w) e')
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embed
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, emb_size, depth, n_heads, mlp_dim, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([nn.TransformerEncoderLayer(d_model=emb_size, nhead=n_heads, dim_feedforward=mlp_dim, dropout=dropout) for _ in range(depth)])
        
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class TransformerDecoder(nn.Module):
    def __init__(self, emb_size, depth, n_heads, mlp_dim, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.ModuleList([nn.TransformerDecoderLayer(d_model=emb_size, nhead=n_heads, dim_feedforward=mlp_dim, dropout=dropout) for _ in range(depth)])
        
    def forward(self, tgt, memory):
        for layer in self.layers:
            tgt = layer(tgt, memory)
        return tgt

class VisionTransformerAutoencoder(nn.Module):
    def __init__(self, img_size=32, patch_size=8, in_channels=3, emb_size=256, depth=10, n_heads=4, mlp_dim=1024, dropout=0.1):
        super(VisionTransformerAutoencoder, self).__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        self.encoder = TransformerEncoder(emb_size, depth, n_heads, mlp_dim, dropout)
        self.decoder = TransformerDecoder(emb_size, depth, n_heads, mlp_dim, dropout)
        self.unpatchify = nn.ConvTranspose2d(emb_size, in_channels, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # Encoder
        patches = self.patch_embedding(x)
        memory = self.encoder(patches) #this extract!

        # Decoder
        decoded_patches = self.decoder(patches, memory)
        decoded_patches = decoded_patches[:, 1:, :]  
        decoded_patches = rearrange(decoded_patches, 'b (h w) e -> b e h w', h=int(x.shape[-2] // self.patch_embedding.patch_size))

        # Reconstruct image
        reconstructed_img = self.unpatchify(decoded_patches)
        return reconstructed_img

# Example usage
img_size = 32
patch_size = 4
in_channels = 3
emb_size = 128
depth = 5
n_heads = 4
mlp_dim = 128
dropout = 0.0

model = VisionTransformerAutoencoder(img_size, patch_size, in_channels, emb_size, depth, n_heads, mlp_dim, dropout)

# Sample input
x = torch.randn(1, 3, 32, 32)
reconstructed_img = model(x)

print(reconstructed_img.shape)  # Should output torch.Size([1, 3, 32, 32])


torch.Size([1, 3, 32, 32])


In [6]:
model = model.to('cuda')

In [8]:
import time
import torch
import wandb
from torch.utils.data import DataLoader
import numpy as np

# Initialize WandB
run_name = f'conv_autoencoder_training_{datetime.datetime.now()}'

# Configuration
batch_size = 512
learning_rate = 1e-3
num_epochs = 100
early_stop_patience = 15  # Number of epochs to wait for improvement
run_path = f'training_checkpoints/{run_name}'

# DataLoader
trainset = ImageDataset(data_path='train_data')
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=5)

# Model, loss function, optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Early Stopping
best_loss = float('inf')
patience_counter = 0

for epoch in range(num_epochs):
    print('========================================')
    print(f'EPOCH: {epoch}')
    time_start = time.perf_counter()
    model.train()
    
    epoch_loss = 0
    for batch_idx, (inputs, _) in enumerate(trainloader):
        inputs = inputs.to('cuda')
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, inputs)  # Reconstruction loss
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(trainloader)
    print(f'Epoch {epoch} Average Loss: {avg_loss}')
    
    
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), f'{run_path}.pth')
        print(f'Saved new best model with loss {best_loss}')
        patience_counter = 0  # Reset patience counter
    else:
        patience_counter += 1
    
    if patience_counter >= early_stop_patience:
        print(f'Early stopping at epoch {epoch} with best loss {best_loss}')
        break
    
    time_epoch = time.perf_counter() - time_start
    print(f'Epoch {epoch} time: {time_epoch/60} minutes')
    print('--------------------------------')

# Load the best model state dict
print(f'Loading model from {run_path}.pth')
model.load_state_dict(torch.load(f'{run_path}.pth'))

EPOCH: 0
Epoch 0 Average Loss: 0.6028989540665063
Saved new best model with loss 0.6028989540665063
Epoch 0 time: 1.60594351301667 minutes
--------------------------------
EPOCH: 1
Epoch 1 Average Loss: 0.5937266796335529
Saved new best model with loss 0.5937266796335529
Epoch 1 time: 1.6158317776333206 minutes
--------------------------------
EPOCH: 2
Epoch 2 Average Loss: 0.5931699441211058
Saved new best model with loss 0.5931699441211058
Epoch 2 time: 1.6028729696166617 minutes
--------------------------------
EPOCH: 3
Epoch 3 Average Loss: 0.5929557772556154
Saved new best model with loss 0.5929557772556154
Epoch 3 time: 1.5864048760500131 minutes
--------------------------------
EPOCH: 4
Epoch 4 Average Loss: 0.5929188313263871
Saved new best model with loss 0.5929188313263871
Epoch 4 time: 1.6152443586999956 minutes
--------------------------------
EPOCH: 5
Epoch 5 Average Loss: 0.5927559653721234
Saved new best model with loss 0.5927559653721234
Epoch 5 time: 1.6029080891666732

KeyboardInterrupt: 

In [None]:
model.load_state_dict(torch.load(f'{run_path}.pth'))
features = []
classes = []
paths = []
model.eval()
trainset = ImageDataset(data_path='train_data')
model = model.to('cuda')
with torch.no_grad():
    for idx in range(0, len(trainset)-1):
        img, cls = trainset[idx]
        classes.append(cls.cpu().detach().numpy())
        feature = model.encoder(img.to('cuda').reshape(1, 3, 32, 32))
        features.append(feature.cpu().detach().numpy())