In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import math

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused Q/K/V projection.

    Args:
        n_heads: number of attention heads; each head works on ``d_embed // n_heads`` features.
        d_embed: size of the last dimension of the input and of the output.
        in_proj_bias: whether the fused Q/K/V projection has a bias term.
        out_proj_bias: whether the output projection has a bias term.
    """

    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # One linear layer yields Q, K and V side by side (Wq | Wk | Wv).
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        # Output projection (Wo), applied once the heads have been merged back.
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: tensor of shape (Batch_Size, Seq_Len, Dim).
            causal_mask: when true, position i cannot attend to positions j > i.

        Returns:
            Tensor of shape (Batch_Size, Seq_Len, Dim).
        """
        batch_size, seq_len, _ = original_shape = x.shape
        head_layout = (batch_size, seq_len, self.n_heads, self.d_head)

        # (B, S, Dim) -> (B, S, 3 * Dim) -> three (B, S, Dim) slices,
        # each reshaped to (B, H, S, Dim / H).
        q, k, v = (
            part.view(head_layout).transpose(1, 2)
            for part in self.in_proj(x).chunk(3, dim=-1)
        )

        # (B, H, S, Dim / H) @ (B, H, Dim / H, S) -> (B, H, S, S)
        scores = torch.matmul(q, k.transpose(-1, -2))

        if causal_mask:
            # True strictly above the diagonal, i.e. on "future" positions.
            future = torch.ones_like(scores, dtype=torch.bool).triu(1)
            scores.masked_fill_(future, -torch.inf)

        # Scale by sqrt(d_head); -inf entries stay -inf and vanish in the softmax.
        scores /= math.sqrt(self.d_head)
        attention = F.softmax(scores, dim=-1)

        # (B, H, S, S) @ (B, H, S, Dim / H) -> (B, H, S, Dim / H) -> (B, S, Dim)
        merged = torch.matmul(attention, v).transpose(1, 2).reshape(original_shape)

        return self.out_proj(merged)

class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from ``x``, keys/values from ``y``.

    Args:
        n_heads: number of attention heads.
        d_embed: feature size of the query input and of the output.
        d_cross: feature size of the context input that keys and values are built from.
        in_proj_bias: whether the Q/K/V projections have bias terms.
        out_proj_bias: whether the output projection has a bias term.
    """

    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        """Let every position of ``x`` attend over all positions of ``y``.

        Args:
            x: query tensor of shape (Batch_Size, Seq_Len_Q, Dim_Q).
            y: context tensor of shape (Batch_Size, Seq_Len_KV, Dim_KV).

        Returns:
            Tensor of shape (Batch_Size, Seq_Len_Q, Dim_Q).
        """
        query_shape = x.shape
        batch_size, _, _ = query_shape

        def split_heads(projected):
            # (B, S, Dim_Q) -> (B, S, H, Dim_Q / H) -> (B, H, S, Dim_Q / H);
            # -1 lets the same layout serve both Seq_Len_Q and Seq_Len_KV.
            return projected.view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(y))
        v = split_heads(self.v_proj(y))

        # (B, H, Seq_Len_Q, Dim_Q / H) @ (B, H, Dim_Q / H, Seq_Len_KV) -> (B, H, Seq_Len_Q, Seq_Len_KV)
        scores = torch.matmul(q, k.transpose(-1, -2))
        scores /= math.sqrt(self.d_head)
        attention = F.softmax(scores, dim=-1)

        # (B, H, Seq_Len_Q, Seq_Len_KV) @ (B, H, Seq_Len_KV, Dim_Q / H) -> (B, H, Seq_Len_Q, Dim_Q / H)
        per_head = torch.matmul(attention, v)

        # Merge the heads back: (B, H, Seq_Len_Q, Dim_Q / H) -> (B, Seq_Len_Q, Dim_Q)
        merged = per_head.transpose(1, 2).contiguous().view(query_shape)

        return self.out_proj(merged)

In [None]:
!pip install tensorly

In [None]:
!pip install einops

In [None]:
import numpy as np
from PIL import Image,ImageFile,ImageOps
import os
import torch
import tensorly as tl
import cv2
from tensorly.decomposition import tucker,parafac
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch import Tensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset,TensorDataset,random_split,SubsetRandomSampler, ConcatDataset
import torch.optim as optim
import copy
from sklearn.model_selection import KFold
ImageFile.LOAD_TRUNCATED_IMAGES = True
from scipy.signal import wiener
torch.set_num_threads(1)

In [None]:
import numpy as np
import torch
import pickle
import copy
class EarlyStopping:
    """Signals when training should stop because the monitored loss stopped improving.

    The object only tracks a best score, a patience counter and the
    ``early_stop`` flag. It does not write any checkpoint: ``path``,
    ``verbose`` and ``val_loss_min`` are stored but never read by ``__call__``.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How many consecutive non-improving calls are tolerated
                            before ``early_stop`` is set.
                            Default: 7
            verbose (bool): Stored on the instance; not used by this class.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Stored on the instance; no checkpoint is written here.
                            Default: 'checkpoint.pt'
            trace_func (function): Called with the counter message on every non-improving call.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Use np.inf: the capitalised alias np.Inf was removed in NumPy 2.0 and
        # would make construction fail with AttributeError there.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            val_loss: value to minimise (lower is better).
            model: accepted for interface compatibility; not used.
        """
        # Work with a "higher is better" score internally.
        score = -val_loss

        if self.best_score is None:
            # First call: just establish the baseline.
            self.best_score = score
        elif score < self.best_score + self.delta:
            # Did not beat the best score by at least delta.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
    

In [None]:
class VAE_AttentionBlock(nn.Module):
    """Residual single-head self-attention over the spatial positions of a feature map.

    Args:
        channels: number of feature-map channels; GroupNorm uses 32 groups here,
            so it must be divisible by 32.
    """

    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        """Apply GroupNorm, attend across pixels, then add the input back.

        Args:
            x: tensor of shape (Batch_Size, Features, Height, Width).

        Returns:
            Tensor of the same shape as ``x``.
        """
        skip = x

        normed = self.groupnorm(x)
        n, c, h, w = normed.shape

        # Flatten the grid into a sequence in which every pixel is one token:
        # (B, C, H, W) -> (B, C, H * W) -> (B, H * W, C)
        tokens = normed.view((n, c, h * w)).transpose(-1, -2)

        # Unmasked self-attention: every pixel can see every other pixel.
        attended = self.attention(tokens)

        # Back to image layout: (B, H * W, C) -> (B, C, H * W) -> (B, C, H, W)
        out = attended.transpose(-1, -2).view((n, c, h, w))

        # Residual connection (in place on the freshly computed tensor).
        out += skip
        return out

class VAE_ResidualBlock(nn.Module):
    """Two (GroupNorm -> SiLU -> 3x3 conv) stages with a skip connection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.

    When the channel counts differ, the skip path goes through a 1x1
    convolution so both branches can be added; otherwise it is an identity.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        needs_projection = in_channels != out_channels
        if needs_projection:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        else:
            self.residual_layer = nn.Identity()

    def forward(self, x):
        """Map (Batch_Size, In_Channels, H, W) to (Batch_Size, Out_Channels, H, W)."""
        # First stage changes the channel count: In_Channels -> Out_Channels.
        hidden = self.conv_1(F.silu(self.groupnorm_1(x)))

        # Second stage keeps Out_Channels.
        hidden = self.conv_2(F.silu(self.groupnorm_2(hidden)))

        # Add the (possibly projected) input back.
        return hidden + self.residual_layer(x)

In [None]:
class VAE_Encoder(nn.Sequential):
    """Small convolutional VAE encoder that returns a sampled latent.

    The stack maps a 3-channel image to a 64-channel feature map at a quarter
    of the input resolution (two stride-2 convolutions). ``forward`` then
    splits those 64 channels into a 32-channel mean and a 32-channel
    log-variance and draws one sample.
    """

    def __init__(self):
        layers = [
            # (B, 3, H, W) -> (B, 32, H, W)
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            # (B, 32, H, W) -> (B, 32, H, W)
            VAE_ResidualBlock(32, 32),
            # Downsample (forward() pads first): (B, 32, H, W) -> (B, 32, H / 2, W / 2)
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0),
            # (B, 32, H / 2, W / 2) -> (B, 64, H / 2, W / 2)
            VAE_ResidualBlock(32, 64),
            # Downsample (forward() pads first): (B, 64, H / 2, W / 2) -> (B, 64, H / 4, W / 4)
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0),
            # The remaining layers keep the shape (B, 64, H / 4, W / 4).
            VAE_ResidualBlock(64, 64),
            VAE_ResidualBlock(64, 64),
            VAE_AttentionBlock(64),
            VAE_ResidualBlock(64, 64),
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
        ]
        super().__init__(*layers)

    def forward(self, x):
        """Encode ``x`` and sample from the predicted Gaussian.

        Args:
            x: image batch of shape (Batch_Size, 3, Height, Width).

        Returns:
            Sampled latent of shape (Batch_Size, 32, Height / 4, Width / 4)
            for even Height and Width.
        """
        for layer in self:
            is_downsampling = getattr(layer, 'stride', None) == (2, 2)
            if is_downsampling:
                # Asymmetric padding: one row/column of zeros on the right and
                # bottom only. F.pad order is (left, right, top, bottom).
                x = F.pad(x, (0, 1, 0, 1))
            x = layer(x)

        # (B, 64, h, w) -> two tensors of shape (B, 32, h, w)
        mean, log_variance = torch.chunk(x, 2, dim=1)

        # Clamp the log-variance to keep exp() in a sane numeric range, then
        # turn it into a standard deviation.
        stdev = torch.clamp(log_variance, -30, 20).exp().sqrt()

        # Reparameterisation: z = mean + stdev * eps with eps ~ N(0, I).
        noise = torch.randn_like(mean)
        return mean + stdev * noise

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image into square patches, embed each one and prepend a class token.

    Args:
        in_channels: channels of the input image.
        patch_size: side length of each square patch.
        emb_size: size of every output token.
        img_size: side length of the (square) input; only used to record ``nPatches``.
    """

    def __init__(self, in_channels, patch_size, emb_size, img_size):
        super().__init__()
        self.patch_size = patch_size
        self.nPatches = (img_size * img_size) // (patch_size ** 2)

        # Each patch is flattened to p1 * p2 * c values before the linear map.
        flattened_patch_dim = patch_size * patch_size * in_channels
        to_patches = Rearrange('b c (h p1)(w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
        self.projection = nn.Sequential(
            to_patches,
            nn.Linear(flattened_patch_dim, emb_size),
        )

        # Learnable class token shared by every sample. No positional
        # embedding is added in this module.
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x):
        """Map (B, C, H, W) to (B, 1 + num_patches, emb_size)."""
        batch, _, _, _ = x.shape
        patch_tokens = self.projection(x)
        # One copy of the class token per sample in the batch.
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch)
        return torch.cat([cls_tokens, patch_tokens], dim=1)

In [None]:
class multiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        emb_size: token size; must be divisible by ``heads``.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``heads``.
    """

    def __init__(self, emb_size, heads, dropout):
        super().__init__()
        if emb_size % heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by heads ({heads})"
            )
        self.heads = heads
        self.emb_size = emb_size
        self.query = nn.Linear(emb_size, emb_size)
        self.key = nn.Linear(emb_size, emb_size)
        self.value = nn.Linear(emb_size, emb_size)
        self.drop_out = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x):
        """Attend over the token dimension.

        Args:
            x: tensor of shape (batch, tokens, emb_size).

        Returns:
            Tensor of shape (batch, tokens, emb_size).
        """
        batch, tokens, _ = x.shape
        head_dim = self.emb_size // self.heads

        def split_heads(projected):
            # (b, n, h * d) -> (b, n, h, d) -> (b, h, n, d)
            return projected.reshape(batch, tokens, self.heads, head_dim).transpose(1, 2)

        queries = split_heads(self.query(x))
        keys = split_heads(self.key(x))
        values = split_heads(self.value(x))

        # Scale the logits BEFORE the softmax, by sqrt of the per-head size.
        # Dividing the softmax output instead (and by sqrt(emb_size)) leaves
        # the logits untempered and produces weights that no longer sum to 1.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / (head_dim ** 0.5)
        attention_maps = F.softmax(scores, dim=-1)
        attention_maps = self.drop_out(attention_maps)

        # (b, h, n, n) @ (b, h, n, d) -> (b, h, n, d) -> (b, n, h * d)
        output = torch.matmul(attention_maps, values)
        output = output.transpose(1, 2).reshape(batch, tokens, self.emb_size)
        return self.projection(output)
class residual(nn.Module):
    """Wrap a callable so that its input is added to its output: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        # The wrapped transformation; its output must be addable to its input.
        self.fn = fn

    def forward(self, x):
        """Return ``self.fn(x) + x``."""
        return self.fn(x) + x

In [None]:
class DeepBlock(nn.Sequential):
    """Residual block: LayerNorm -> 2-head self-attention -> LayerNorm, plus the input.

    Args:
        emb_size: token size passed to both LayerNorms and the attention layer.
        drop_out: dropout probability passed to the attention layer.
    """

    def __init__(self,emb_size:int =256 ,drop_out:float=0.0):#64
        # Assemble the inner stack first, then wrap it in the skip connection.
        attention_stack = nn.Sequential(
            nn.LayerNorm(emb_size),
            multiHeadAttention(emb_size, 2, drop_out),
            nn.LayerNorm(emb_size),
        )
        super().__init__(residual(attention_stack))

class Classification(nn.Sequential):
    """Classification head: light dropout, LayerNorm, then a linear layer to class scores.

    Args:
        emb_size: size of the incoming feature vector.
        n_classes: number of output scores.
    """

    def __init__(self, emb_size:int=256, n_classes:int=2):
        # No token pooling happens here; the input is used as given.
        head_layers = (
            nn.Dropout(0.01),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes),
        )
        super().__init__(*head_layers)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbeddingVit(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    Args:
        in_channels: channels of the input image.
        patch_size: side length of each patch (used as kernel size and stride).
        emb_dim: size of each patch embedding.
        num_patches: stored on the module; not used in ``forward``.
    """

    def __init__(self, in_channels, patch_size, emb_dim, num_patches):
        super().__init__()
        # kernel_size == stride, so every patch is seen by exactly one window.
        self.patch_embedding = nn.Conv2d(
            in_channels, emb_dim, kernel_size=patch_size, stride=patch_size
        )
        self.num_patches = num_patches

    def forward(self, x):
        """Map [batch, channels, height, width] to [batch, num_patches, emb_dim]."""
        # [batch, emb_dim, patches_h, patches_w]
        grid = self.patch_embedding(x)
        # Collapse the spatial grid, then put the patch axis before the features.
        sequence = grid.flatten(2).transpose(1, 2)
        return sequence

class MultiHeadAttentionVit(nn.Module):
    """Multi-head scaled dot-product self-attention over patch tokens.

    Args:
        emb_dim: token size.
        num_heads: number of heads; each works on ``emb_dim // num_heads`` features.
    """

    def __init__(self, emb_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads

        self.query = nn.Linear(emb_dim, emb_dim)
        self.key = nn.Linear(emb_dim, emb_dim)
        self.value = nn.Linear(emb_dim, emb_dim)

        self.out_proj = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        """Map [batch_size, num_patches, emb_dim] to a tensor of the same shape."""
        batch_size, num_patches, emb_dim = x.shape
        per_head = (batch_size, num_patches, self.num_heads, self.head_dim)

        # Project, split the feature axis into heads, and move heads forward:
        # [batch_size, num_heads, num_patches, head_dim]
        Q = self.query(x).view(per_head).transpose(1, 2)
        K = self.key(x).view(per_head).transpose(1, 2)
        V = self.value(x).view(per_head).transpose(1, 2)

        # Scaled dot-product attention weights over the patch axis.
        logits = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        weights = F.softmax(logits, dim=-1)

        # Weighted sum of values, then merge the heads back into emb_dim.
        mixed = weights @ V
        merged = mixed.transpose(1, 2).contiguous().view(batch_size, num_patches, emb_dim)
        return self.out_proj(merged)

class TransformerBlockVit(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each with a skip connection.

    Args:
        emb_dim: token size.
        num_heads: number of attention heads.
        mlp_dim: hidden size of the feed-forward network.
        dropout: dropout probability used inside the feed-forward network.
    """

    def __init__(self, emb_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_dim)
        self.attn = MultiHeadAttentionVit(emb_dim, num_heads)
        self.norm2 = nn.LayerNorm(emb_dim)
        feed_forward = [
            nn.Linear(emb_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, emb_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        # Attention sub-layer: normalise first, add the result to the input.
        attended = self.attn(self.norm1(x))
        x = x + attended

        # Feed-forward sub-layer, same pre-norm + residual pattern.
        transformed = self.mlp(self.norm2(x))
        return x + transformed

class TinyViT(nn.Module):
    """Compact Vision Transformer that returns the normalised class-token embedding.

    Note: ``self.head`` is created but ``forward`` does not apply it; the
    output has ``emb_dim`` features, not ``num_classes``.

    Args:
        in_channels: channels of the input image.
        patch_size: side length of each square patch.
        emb_dim: token size.
        num_heads: attention heads per block.
        num_layers: number of transformer blocks.
        num_classes: output size of the (unused in ``forward``) linear head.
        dropout: dropout probability inside each block's MLP.
        image_size: side length used to size the positional embedding.
    """

    def __init__(self,
                 in_channels=3,
                 patch_size=16,
                 emb_dim=192,
                 num_heads=3,
                 num_layers=4,
                 num_classes=1000,
                 dropout=0.1,
                 image_size=32):
        super().__init__()

        # Patches along one side, squared.
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2

        self.patch_embed = PatchEmbeddingVit(
            in_channels=in_channels,
            patch_size=patch_size,
            emb_dim=emb_dim,
            num_patches=num_patches,
        )

        # Learnable class token and positional embedding (one extra slot for
        # the class token), both initialised to zeros.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim))

        encoder_blocks = []
        for _ in range(num_layers):
            encoder_blocks.append(
                TransformerBlockVit(
                    emb_dim=emb_dim,
                    num_heads=num_heads,
                    mlp_dim=emb_dim * 4,
                    dropout=dropout,
                )
            )
        self.blocks = nn.ModuleList(encoder_blocks)

        self.norm = nn.LayerNorm(emb_dim)

        # Linear classification head; see the class note about forward().
        self.head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Encode an image batch.

        Args:
            x: tensor of shape [batch_size, channels, height, width].

        Returns:
            Tensor of shape [batch_size, emb_dim]: the class token after the
            final LayerNorm.
        """
        tokens = self.patch_embed(x)

        # Prepend one class token per sample.
        cls_tokens = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        # If the token count differs from what the positional table was built
        # for, linearly resample the table along the token axis.
        pos_embed = self.pos_embed
        if tokens.size(1) != pos_embed.size(1):
            pos_embed = F.interpolate(
                pos_embed.transpose(1, 2),
                size=tokens.size(1),
                mode='linear',
                align_corners=False,
            ).transpose(1, 2)

        tokens = tokens + pos_embed

        for block in self.blocks:
            tokens = block(tokens)

        # Normalise, then keep only the class token (index 0).
        return self.norm(tokens)[:, 0]


In [None]:
class Model(nn.Module):
    """TinyViT feature extractor followed by a small classification head.

    Only ``encoder2`` (TinyViT) and ``Classification`` take part in
    ``forward``. ``encoder1``, ``PatchEmbedding1``, ``PatchEmbedding2`` and
    ``DeepBlock`` are still constructed, in the original order, so that the
    parameter set and ``state_dict`` keys stay the same, but they are unused.

    Args:
        emb_size: token size for the (unused) patch-embedding modules.
        drop_out: accepted for interface compatibility; not used.
        n_classes: number of output classes.
        in_channels: channels of the input image.
        patch_size1: patch size of the (unused) first patch embedding.
        patch_size2: patch size of the (unused) second patch embedding.
        image_size: side length of the input images; sizes TinyViT's
            positional embedding.
    """

    def __init__(self,emb_size,drop_out, n_classes,in_channels,patch_size1,patch_size2,image_size):
        super().__init__()
        # Feature width shared by TinyViT's output and the classifier input.
        vit_emb_dim = 192

        self.encoder1 = VAE_Encoder()
        self.encoder2 = TinyViT(
            in_channels=in_channels,
            emb_dim=vit_emb_dim,
            num_classes=n_classes,
            image_size=image_size,
        )
        self.PatchEmbedding1 = PatchEmbedding(32, patch_size1, emb_size, 64)
        self.PatchEmbedding2 = PatchEmbedding(32, patch_size2, emb_size, 64)
        self.DeepBlock = DeepBlock(emb_size=16)
        # Honour n_classes instead of hard-coding two outputs.
        self.Classification = Classification(emb_size=vit_emb_dim, n_classes=n_classes)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, n_classes).

        Args:
            x: image batch of shape (batch, in_channels, height, width).
        """
        # (batch, 192) class-token features from the ViT encoder.
        features = self.encoder2(x)
        logits = self.Classification(features)
        return F.log_softmax(logits, dim=1)

In [None]:
def get_mag(f):
    """Turn a 2-D spectrum into a thresholded log-magnitude tensor.

    Args:
        f: array of frequency coefficients, e.g. the output of ``np.fft.fft2``.

    Returns:
        ``torch.Tensor`` holding ``log(1 + |F|)`` of the centred spectrum,
        where coefficients whose magnitude is below 0.1% of the maximum
        magnitude have been set to zero.
    """
    # Move the zero-frequency component to the centre of the array.
    centred = np.fft.fftshift(f)
    magnitude = np.abs(centred)

    # Drop everything weaker than 0.1% of the strongest coefficient.
    cutoff = 0.001 * np.max(magnitude)
    kept = np.where(magnitude < cutoff, 0, magnitude)

    # Log scaling compresses the dynamic range of the spectrum.
    return torch.tensor(np.log(1 + kept))
    

In [None]:
def readImage(imagePath):
    """Load an image and return the log-magnitude spectrum of each colour channel.

    Args:
        imagePath: path of the image file, read with ``cv2.imread``.

    Returns:
        Tensor of shape (3, H, W): one ``get_mag`` spectrum per channel, in
        the channel order of the loaded array.

    Raises:
        OSError: if OpenCV cannot read the file (``cv2.imread`` returned None).
    """
    img = cv2.imread(imagePath)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError(f"Could not read image: {imagePath}")

    # One spectrum per channel. All three channels are stacked; previously
    # the second spectrum was stacked twice and the third was discarded.
    spectra = [get_mag(np.fft.fft2(img[:, :, channel])) for channel in range(3)]
    return torch.stack(spectra, 0)

In [None]:
from tqdm import tqdm

class buildDataset(Dataset):
    """In-memory dataset holding up to the first 40000 listed images of each class folder.

    Every entry of ``rootFolder`` is treated as a class directory: the one
    named ``FAKE`` gets label 0, any other entry gets label 1. All images are
    preprocessed with ``readImage`` during construction.
    """

    def __init__(self,rootFolder):
        self.rootFolder = rootFolder
        self.images = []
        print("yo")
        for entry in os.listdir(self.rootFolder):
            label = 0 if entry == 'FAKE' else 1
            class_dir = os.path.join(self.rootFolder, entry)
            # position counts from 1; stop after the 40000th file.
            for position, file_name in enumerate(tqdm(os.listdir(class_dir)), start=1):
                if position > 40000:
                    break
                spectrum = readImage(os.path.join(class_dir, file_name))
                self.images.append([spectrum, label])

    def __len__(self):
        """Number of loaded samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(input, label)`` for the sample at ``index``."""
        sample = self.images[index]
        return (sample[0], sample[1])

In [None]:
from tqdm import tqdm

class buildDataset2(Dataset):
    """In-memory dataset holding the images of each class folder after the first 40000.

    Every entry of ``rootFolder`` is treated as a class directory: the one
    named ``FAKE`` gets label 0, any other entry gets label 1. All images are
    preprocessed with ``readImage`` during construction.
    """

    def __init__(self,rootFolder):
        self.rootFolder = rootFolder
        self.images = []
        print("yo")
        for entry in os.listdir(self.rootFolder):
            label = 0 if entry == 'FAKE' else 1
            class_dir = os.path.join(self.rootFolder, entry)
            for position, file_name in enumerate(tqdm(os.listdir(class_dir)), start=1):
                # Skip files 1..40000 inclusive. buildDataset loads exactly
                # those, so the previous `< 40000` test put the 40000th file
                # of every class into both datasets.
                if position <= 40000:
                    continue
                spectrum = readImage(os.path.join(class_dir, file_name))
                self.images.append([spectrum, label])

    def __len__(self):
        """Number of loaded samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(input, label)`` for the sample at ``index``."""
        sample = self.images[index]
        return (sample[0], sample[1])

In [None]:
# Training data: the leading portion of every class folder under train/,
# served in shuffled mini-batches of 16.
dataset = buildDataset('/kaggle/input/cifake-real-and-ai-generated-synthetic-images/train')
trainDs = DataLoader(dataset, batch_size=16, shuffle=True, pin_memory=False)
dataset_size = len(dataset)

In [None]:
# Validation data: use buildDataset2, which reads the tail of each class
# folder. Building this loader with buildDataset again would validate on
# exactly the images the model is trained on.
dataset = buildDataset2('/kaggle/input/cifake-real-and-ai-generated-synthetic-images/train')
dataset_size = len(dataset)
valDs = DataLoader(dataset,16,shuffle = True,pin_memory = False)

In [None]:
# Keep preferring the second GPU when there is one, but fall back to cuda:0 on
# single-GPU hosts, where "cuda:1" is an invalid device, and to the CPU
# when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
# device = 'cpu'

# Positional arguments: emb_size, drop_out, n_classes, in_channels,
# patch_size1, patch_size2, image_size.
model = Model(16,0,2,3,8,8,32)
model.to(device)
# Model.forward returns log-probabilities; CrossEntropyLoss applies
# log_softmax again, which leaves log-probabilities unchanged.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.00001)

In [None]:
def _run_epoch(model, criterion, loader, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns ``(mean_loss, accuracy)`` where accuracy is the number of correct
    predictions divided by ``len(loader.sampler)``.
    """
    training = optimizer is not None
    current_corrects = 0.0
    losses = []
    # Gradients are only needed while training; skipping them during
    # evaluation avoids building an autograd graph for every batch.
    with torch.set_grad_enabled(training):
        for inputs1, labels in loader:
            input1 = inputs1.to(device).float()
            labels = labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(input1)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            current_corrects += torch.sum(preds == labels.data)
    accuracy = current_corrects.double() / len(loader.sampler)
    return np.mean(losses), accuracy


def trainModel(model,criterion,optimizer,epochs,trainDs,valDs,path,patience=5):
    """Train ``model`` with per-epoch validation and accuracy-based early stopping.

    Uses the module-level ``device`` for tensor placement.

    Args:
        model: network to train (modified in place).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: maximum number of epochs.
        trainDs: DataLoader yielding ``(inputs, labels)`` training batches.
        valDs: DataLoader yielding ``(inputs, labels)`` validation batches.
        path: forwarded to ``EarlyStopping``.
        patience: epochs without sufficient validation-accuracy improvement
            (delta 0.01) before training stops.

    Returns:
        ``(model, train_losses, train_accs, val_losses, val_accs, best_model_wts)``:
        the model with its final weights, the per-epoch histories, and a
        deep copy of the state dict from the epoch with the highest
        validation accuracy.
    """
    batch_train_loss = []
    batch_train_acc = []
    batch_val_acc = []
    batch_val_loss = []
    early_stopping = EarlyStopping(patience=patience, verbose=True,delta = 0.01 , path=path)

    best_model_wts = copy.deepcopy(model.state_dict())
    max_acc = 0
    for epoch in tqdm(range(epochs)):
        model.train()
        epoch_train_loss, train_acc = _run_epoch(model, criterion, trainDs, optimizer)
        batch_train_loss.append(epoch_train_loss)
        batch_train_acc.append(train_acc)

        model.eval()
        epoch_val_loss, val_acc = _run_epoch(model, criterion, valDs)
        batch_val_loss.append(epoch_val_loss)
        batch_val_acc.append(val_acc)

        if max_acc < val_acc:
            # Remember the new best accuracy; without this update every epoch
            # with non-zero accuracy would overwrite the "best" weights.
            max_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        # EarlyStopping minimises its input, so pass the negated accuracy.
        early_stopping(-1*val_acc, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        print('Epoch Number {}: Train- Loss{:.4f} Acc: {:.4f}'.format(epoch,batch_train_loss[-1],batch_train_acc[-1]))
        print('Epoch Number {}: Val- Loss{:.4f} Acc: {:.4f}'.format(epoch,batch_val_loss[-1],batch_val_acc[-1]))

    return model,batch_train_loss,batch_train_acc,batch_val_loss,batch_val_acc,best_model_wts


In [None]:
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

def readImage2(imagePath):
    """Compute per-channel FFT log-magnitudes of an image and plot them next to it.

    Each channel is transformed with a 256x256 FFT (``s=(256, 256)``) and
    passed through ``get_mag``.

    Args:
        imagePath: path of the image file, read with ``cv2.imread``.

    Returns:
        Tensor of shape (3, 256, 256) with the R, G and B spectra stacked
        along the first axis.
    """
    # OpenCV loads BGR; convert so channel 0/1/2 are R/G/B and imshow looks right.
    img = cv2.cvtColor(cv2.imread(imagePath), cv2.COLOR_BGR2RGB)

    spectra = [
        get_mag(np.fft.fft2(img[:, :, channel], s=(256, 256)))
        for channel in range(3)
    ]
    result = torch.stack(spectra, 0)

    plt.figure(figsize=(12, 6))

    # Panel 1: the image itself.
    plt.subplot(2, 4, 1)
    plt.imshow(img)
    plt.title("Original Image")
    plt.axis("off")

    # Panels 2-4: one spectrum per channel, log-scaled again for display.
    panel_titles = ("FFT (Channel R)", "FFT (Channel G)", "FFT (Channel B)")
    for panel, (title, spectrum) in enumerate(zip(panel_titles, spectra), start=2):
        plt.subplot(2, 4, panel)
        plt.imshow(torch.log(1 + spectrum).numpy(), cmap="gray")
        plt.title(title)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

    return result


In [None]:
# Visual check of the FFT preprocessing on one FAKE training image; the
# returned spectrum tensor is displayed as the cell output.
readImage2("/kaggle/input/cifake-real-and-ai-generated-synthetic-images/train/FAKE/1000 (10).jpg")

In [None]:
# Train for at most 50 epochs with the default patience of 5. `wts` receives
# the best_model_wts state dict returned by trainModel; the last argument is
# only forwarded to EarlyStopping as `path`.
model,train_loss,train_acc,val_loss,val_acc,wts = trainModel(model,criterion,optimizer,50,trainDs,valDs,'/kaggle/working/')

In [None]:
# Extract the model parameters. nn.Module exposes them through state_dict();
# it has no Keras-style get_weights().
wts = model.state_dict()

# Destination file for the serialized parameters.
file_path = '/kaggle/working/model_weights.pth'

# torch.save stores the tensors in a form torch.load can remap to another
# device, which a plain pickle of CUDA tensors cannot do.
torch.save(wts, file_path)

# Load the weights back and set them on the model
# wts = torch.load(file_path, map_location=device)
# model.load_state_dict(wts)


In [None]:
from tqdm import tqdm

class buildDataset3(Dataset):
    """In-memory dataset holding every image of every class folder.

    Every entry of ``rootFolder`` is treated as a class directory: the one
    named ``FAKE`` gets label 0, any other entry gets label 1. All images are
    preprocessed with ``readImage`` during construction.
    """

    def __init__(self,rootFolder):
        self.rootFolder = rootFolder
        self.images = []
        print("yo")
        for entry in os.listdir(self.rootFolder):
            label = 0 if entry == 'FAKE' else 1
            class_dir = os.path.join(self.rootFolder, entry)
            for file_name in tqdm(os.listdir(class_dir)):
                spectrum = readImage(os.path.join(class_dir, file_name))
                self.images.append([spectrum, label])

    def __len__(self):
        """Number of loaded samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(input, label)`` for the sample at ``index``."""
        sample = self.images[index]
        return (sample[0], sample[1])

In [None]:
# Held-out test data: every image under test/, served in shuffled mini-batches of 16.
dataset = buildDataset3('/kaggle/input/cifake-real-and-ai-generated-synthetic-images/test')
testDs = DataLoader(dataset, batch_size=16, shuffle=True, pin_memory=False)
dataset_size = len(dataset)

In [None]:
from sklearn.metrics import confusion_matrix, classification_report

# Evaluate the trained model on the test loader.
model.eval()
current_corrects = 0.0
val_loss = []
all_preds = []
all_labels = []

# Inference only: disabling gradient tracking avoids building an autograd
# graph for every batch, which costs memory and time for no benefit here.
with torch.no_grad():
    for batchNum, (inputs1, labels) in enumerate(testDs):
        input1 = inputs1.to(device).float()
        labels = labels.to(device)
        outputs = model(input1)
        _, preds = torch.max(outputs, 1)

        # Store predictions and labels for the sklearn metrics below
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # Accumulate loss and the number of correct predictions
        loss = criterion(outputs, labels)
        val_loss.append(loss.item())
        current_corrects += torch.sum(preds == labels.data)

# Accuracy over all test samples
val_acc = current_corrects.double() / len(testDs.sampler)
print("Test Accuracy:", val_acc)

# Confusion matrix with rows = true class, columns = predicted class
conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1])
print("Confusion Matrix:")
print(conf_matrix)

# Precision, recall, F1-score and support for each class
class_report = classification_report(all_labels, all_preds, labels=[0, 1], target_names=["Class 0", "Class 1"])
print("Classification Report:")
print(class_report)
