In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("iamsouravbanerjee/animal-image-dataset-90-different-animals")

print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Path to dataset files: C:\Users\bahaa\.cache\kagglehub\datasets\iamsouravbanerjee\animal-image-dataset-90-different-animals\versions\5


In [2]:
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation, RandomResizedCrop,CenterCrop
from torch.utils.data import DataLoader, random_split
import torch
from tqdm import tqdm
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau,StepLR
import datasets
import pandas as pd
from PIL import Image
import os
from torch import nn
from torchvision import datasets, transforms
from accelerate import Accelerator
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist ##these are for multiple gpu
accelerator = Accelerator()
import inspect

In [3]:
transform = Compose([
    # Randomly flip images horizontally
    RandomHorizontalFlip(p=0.5),
    Resize(128),
    CenterCrop(128),
    # Convert images to PyTorch tensors and scale to [0, 1]
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
animal_dataset = datasets.ImageFolder(root=path,  # Specify the root directory of the dataset
                               transform=transform)  # Apply the defined transformations to the dataset


dataloader = torch.utils.data.DataLoader(animal_dataset, 32, shuffle=True, num_workers=2,drop_last=True)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
import random
class GetPatchEmbeddings(nn.Module):
    def __init__(self, patch_size=16):
        super(GetPatchEmbeddings, self).__init__()
        self.patch_size = patch_size

    def divide_into_patches(self,image , patch_size=16):
        """
        Divides an image into non-overlapping patches.

        Args:
            image (torch.Tensor): Input image of shape (C, H, W).
            patch_size (int): Size of each square patch (patch_size x patch_size).

        Returns:
            torch.Tensor: A tensor of patches of shape (num_patches, C, patch_size, patch_size).
        """
        image = image.to(device)
        B,C, H, W = image.shape
        NUM_PATCHES = int(H*W/(patch_size**2))
        patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).to(device)
        patches = patches.permute(0,2, 3, 1, 4, 5).reshape(B, NUM_PATCHES,C, patch_size, patch_size).to(device)
        return patches

    def sample_context_blocks(self,patches, m=49, image_size=(128,128), patch_size=16 ):
        """
        Randomly selects m continuous patches forming a square block (LxL).

        Args:
            patches (torch.Tensor): Tensor of patches of shape (num_patches, C, patch_size, patch_size).
            image_size (tuple): Original image size as (H, W).
            patch_size (int): Size of each patch (patch_size x patch_size).
            m (int): Number of continuous patches to select.

        Returns:
            list: Indices of selected patches.
            torch.Tensor: Selected patches of shape (m, C, patch_size, patch_size).
        """
        H, W = image_size
        n_patches_h = H // patch_size
        n_patches_w = W // patch_size
        patches = patches.to(device)
        # Compute side length of the square block (LxL = m)
        side_length = int(m ** 0.5)
        assert side_length ** 2 == m, "m must be a perfect square to form a square block."

        # Randomly select a top-left corner for the square block
        max_row = n_patches_h - side_length
        max_col = n_patches_w - side_length
        start_row = random.randint(0, max_row)
        start_col = random.randint(0, max_col)

        # Collect indices of the patches in the square block
        context_indices = []
        for _ in range(patches.shape[0]):
            context_indices.append([
            (start_row + i) * n_patches_w + (start_col + j)
            for i in range(side_length)
            for j in range(side_length)
        ])

        masked_patches = patches.clone().to(device)
        for j in range(patches.shape[0]):
            for i in range(len(patches[j])):
                if i not in context_indices[j]:
                    masked_patches[j][i] = torch.zeros_like(masked_patches[j][i])  # Mask unselected patches to zero

        return masked_patches, context_indices

    def sample_target_blocks(self,patches, m =9,  image_size=(128,128), patch_size=16):
        """
        Randomly selects m continuous patches forming a square block (LxL).

        Args:
            patches (torch.Tensor): Tensor of patches of shape (num_patches, C, patch_size, patch_size).
            image_size (tuple): Original image size as (H, W).
            patch_size (int): Size of each patch (patch_size x patch_size).
            m (int): Number of continuous patches to select.

        Returns:
            list: Indices of selected patches.
            torch.Tensor: Selected patches of shape (m, C, patch_size, patch_size).
        """
        patches = patches.to(device)
        H, W = image_size
        n_patches_h = H // patch_size
        n_patches_w = W // patch_size
        ar = (1.5 - 0.75) * torch.rand((1)) + 0.75 # aspect ratio

        # Compute side length of the target block
        side_length_h = int(ar*(m)**(0.5))
        side_length_w = int(m/side_length_h)

        # Randomly select a top-left corner for the square block
        max_row = n_patches_h - side_length_h
        max_col = n_patches_w - side_length_w
        start_row = random.randint(0, max_row)
        start_col = random.randint(0, max_col)

        # Collect indices of the patches in the square block
        target_indices = []
        for _ in range(patches.shape[0]):
            target_indices.append([
            (start_row + i) * n_patches_w + (start_col + j)
            for i in range(side_length_h)
            for j in range(side_length_w)
        ])

        masked_patches = patches.clone().to(device)
        for j in range(patches.shape[0]):
            for i in range(len(patches[j])):
                if i not in target_indices[j]:
                    masked_patches[j][i] = torch.zeros_like(masked_patches[j][i])  # Mask unselected patches to zero

        return masked_patches, target_indices

    def remove_overlaps(self,context_blocks, context_indices, NUM_TARGETS=4):

        TARGET_INDICES = []
        for i in  range(NUM_TARGETS):
            
            target,indices = context_blocks.to(device),context_indices
            TARGET_INDICES.append(indices)

            if i == 0 :
                TARGET_BLOCKS = target.clone().to(device)
            else:
                TARGET_BLOCKS = torch.cat((TARGET_BLOCKS,target),dim=1).to(device)

            for i in range(len(indices)):
                for idx in indices[i]:
                    context_indices_i = set(context_indices[i])
                    if idx  in context_indices_i:
                        context_blocks[i][idx] = torch.zeros_like(context_blocks[i][idx])
                        context_indices[i].remove(idx)
            context_blocks, TARGET_BLOCKS = context_blocks.to(device), TARGET_BLOCKS.to(device)
        return context_blocks, context_indices ,TARGET_BLOCKS.reshape(-1,4,64,3,16,16), TARGET_INDICES

    def forward(self, x):

        patches = self.divide_into_patches(x)
        context, context_indices = self.sample_context_blocks(patches)
        context, context_indices,targets, target_indices  = self.remove_overlaps(context, context_indices)
        return context, context_indices, targets, target_indices



In [5]:
import numpy as np
class GetPositionalEmbeddings(nn.Module):
    def __init__(self):
        super(GetPositionalEmbeddings, self).__init__()

    def get_1d_sincos_pos_embed_from_grid(self,embed_dim, pos):
        """
        embed_dim: output dimension for each position
        pos: a list of positions to be encoded: size (M,)
        out: (M, D)
        """
        assert embed_dim % 2 == 0
        omega = np.arange(embed_dim // 2, dtype=float)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega   # (D/2,)

        pos = pos.reshape(-1)   # (M,)
        out = np.einsum('m,d->md', pos, omega)   # (M, D/2), outer product

        emb_sin = np.sin(out)  # (M, D/2)
        emb_cos = np.cos(out)  # (M, D/2)

        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
        return emb

    def get_1d_sincos_pos_embed(self,embed_dim, grid_size, cls_token=False):
        """
        grid_size: int of the grid length
        return:
        pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
        """
        grid = np.arange(grid_size, dtype=float)
        pos_embed = self.get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
        if cls_token:
            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
        return pos_embed


    def get_2d_sincos_pos_embed_from_grid(self,embed_dim, grid):
        assert embed_dim % 2 == 0

        # use half of dimensions to encode grid_h
        emb_h = self.get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
        emb_w = self.get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
        return emb

    def get_2d_sincos_pos_embed(self,embed_dim, grid_size, cls_token=False):
        """
        grid_size: int of the grid height and width
        return:
        pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
        """
        grid_h = np.arange(grid_size, dtype=float)
        grid_w = np.arange(grid_size, dtype=float)
        grid = np.meshgrid(grid_w, grid_h)  # here w goes first
        grid = np.stack(grid, axis=0)

        grid = grid.reshape([2, 1, grid_size, grid_size])
        pos_embed = self.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
        if cls_token:
            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
        return pos_embed

    def forward(self, grid_size):
        pos_emb = self.get_2d_sincos_pos_embed(768, grid_size)
        return pos_emb

In [6]:
from torch.nn import functional as F
class CausalSelfAttention(nn.Module):

    def __init__(self, n_embd,n_head):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd,n_embd)
        
        # regularization
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout=nn.Dropout(0.1)
        

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        y = self.dropout(y)
        #print(f"this is for self : {y.shape}")
        return y

In [7]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd,ff_hid_dim):
        super().__init__()
        self.c_fc    = nn.Linear(n_embd, ff_hid_dim)
        self.gelu    = nn.GELU(approximate='tanh')
        self.c_proj  = nn.Linear(ff_hid_dim, n_embd)
        
        
    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
       
        
        return x
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head,ff_fid_dim):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = CausalSelfAttention(n_embd,n_head)
        self.ffwd = FeedFoward(n_embd,ff_fid_dim)
        
        
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        
        x = x + self.ffwd(self.ln2(x))
        return x

In [8]:
image_size=128
from typing import Optional
class VisionTransformer(nn.Module):
    def __init__(self,patchembedder,positionembedder,image_size ,patch_size, in_channels, embed_dim, num_heads,
                 num_layers, ff_hid_dim, max_len: int = 512, target = False):
        super(VisionTransformer, self).__init__()

        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_embedding = patchembedder
        self.position_embedding = positionembedder
        self.target = target

        # Transformer encoder layers
        self.encoder_layers = nn.ModuleList([
            Block(embed_dim, num_heads,ff_hid_dim)
            for _ in range(num_layers)
        ]).to(device)


    def forward(self, x: torch.Tensor, target_rep: Optional[torch.Tensor] = None, target_indices :Optional [list] = None):
        # Step 1: Patch embedding
        if self.target == False:
            batch_size = x.size(0)
            x,_,target,target_indices = self.patch_embedding(x) # (B, embed_dim, H/patch_size, W/patch_size)

            x = x.flatten(2) # (B, num_patches, embed_dim)

            # Step 2: Add positional embedding
            pos_emb = torch.from_numpy(self.position_embedding(8)).unsqueeze(dim=0).repeat_interleave(repeats = 32, dim=0).to(device)
            pos_emb = pos_emb.float()
            x = x +  pos_emb

            # Step 3: Pass through transformer layers
            for layer in self.encoder_layers:
                x = layer(x)

            # Step 4: Return output for each patch
            x, target  =x.to(device), target.to(device)
            return x , target, target_indices # (B, num_patches, embed_dim)
        else:
            batch_size = x.size(0)
            target_representations = []
            target_blocks = target_rep.to(device)
            for i in range(target_blocks.shape[1]):
                x = target_blocks[:,i,:,:].flatten(2) # (B, num_patches, embed_dim)

                # Step 2: Add positional embedding
                pos_emb = torch.from_numpy(self.position_embedding(8)).unsqueeze(dim=0).repeat_interleave(repeats = 32, dim=0).to(device)
                pos_emb = pos_emb.float()
                x = x+  pos_emb

                # Step 3: Pass through transformer layers
                for layer in self.encoder_layers:
                    x = layer(x)

                target_representations.append(x)
            return target_representations  # (B, num_patches, embed_dim)
   

In [9]:
get_positional_embeddings = GetPositionalEmbeddings().to(device)
patchembedder  = GetPatchEmbeddings().to(device)
context_enocder = VisionTransformer(
        patchembedder,
        get_positional_embeddings,
        image_size=128,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        num_heads=8,
        num_layers=6,
        ff_hid_dim=768,
        
        target=False
    )
target_encoder = VisionTransformer(
        patchembedder,
        get_positional_embeddings,
        image_size=128,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        num_heads=8,
        num_layers=6,
        ff_hid_dim=768,
        
        target=True
    )

In [10]:
class VisionTransformer_predictor(nn.Module):
    def __init__(self,positionembedder, patch_size, image_size, in_channels, embed_dim, num_heads,
                 num_layers, ff_hid_dim, max_len: int = 512):
        super(VisionTransformer_predictor, self).__init__()

        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.position_embedding = positionembedder.to(device)

        self.encoder_layers = nn.ModuleList([
            Block(embed_dim, num_heads,ff_hid_dim)
            for _ in range(num_layers)
        ]).to(device)
        self.main = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 768)
        )
    def get_mask_embeddings(self,target_list : list):

        pos_emb = torch.from_numpy(self.position_embedding(8)).unsqueeze(dim=0).repeat_interleave(repeats = 32, dim=0).to(device)
        pos_emb = pos_emb.float()
        for i in range(len(target_list)):
            pos_masks = self.position_embedding.get_1d_sincos_pos_embed_from_grid(768, pos = np.array(target_list[i]))
            pos_masks = torch.from_numpy(pos_masks).to(device)
            pos_masks = torch.cat([pos_masks, torch.zeros(pos_emb[i].shape[0] - pos_masks.shape[0], 768).to(device)], dim=0).to(device)
            pos_emb[i] =  pos_emb[i] + pos_masks
        return pos_emb

    def forward(self, x: torch.Tensor, target_lists: list):
        x = x.to(device)
        pos_emb = self.get_mask_embeddings(target_lists)
        pos_emb = pos_emb.float()
        x = x + pos_emb

        # Step 3: Pass through transformer layers
        for layer in self.encoder_layers:
            x = layer(x)

        return x
   

In [11]:
predictor = VisionTransformer_predictor(
        get_positional_embeddings,
        image_size=128,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        num_heads=4,
        num_layers=6,
        ff_hid_dim=384,

    )

In [12]:
criterion = torch.nn.CrossEntropyLoss()


In [13]:
criterion = torch.nn.CrossEntropyLoss()




In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Cihaz tanımlaması (GPU varsa kullanır)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_iJEPA(context_model, target_model, predictor_model, animal_dataset, num_epochs=10):
    # Modelleri cihaza taşıyoruz
    context_model = context_model.to(device)
    target_model = target_model.to(device)
    predictor_model = predictor_model.to(device)
    
    # DataLoader oluşturuluyor
    dataloader = torch.utils.data.DataLoader(
        animal_dataset, 
        batch_size=32, 
        shuffle=True, 
        num_workers=4, 
        drop_last=True
    )
    
    # Optimizatörler
    context_opt = torch.optim.AdamW(context_model.parameters(), lr=1e-4)
    predictor_opt = torch.optim.AdamW(predictor_model.parameters(), lr=1e-4)
    scheduler = StepLR(predictor_opt, step_size=3, gamma=0.1)

    # target_model parametrelerini donduruyoruz
    for param in target_model.parameters():
        param.requires_grad = False

    # Loss fonksiyonunu tanımlıyoruz: 1 - cosine similarity (ortalaması)
    # Böylece loss her zaman 0 veya pozitif olacak.
    criterion = lambda pred, target: (1 - F.cosine_similarity(pred, target, dim=-1)).mean()
    
    train_losses = []
    context_model.train()
    predictor_model.train()
    
    for epoch in range(num_epochs):
        total_loss = 0.0
        for idx, (x, _) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            x = x.to(device)
            
            # context_model'den context, target ve target_indices elde ediliyor
            context, target, target_indices = context_model(x)
            
            # target_model'den hedef temsiller (target representations) alınıyor
            target_representations = target_model(x, target, target_indices)
            
            loss_accum = 0.0
            # Her bir hedef temsil için loss hesaplanıyor
            for i in range(len(target_representations)):
                # predictor_model ile context temsilleri elde ediliyor.
                # context.detach() ile geriye yayılımın target_model'e gitmesi engelleniyor.
                context_representations = predictor_model(context.detach(), target_indices[i])
                
                # Loss: 1 - cosine_similarity, böylece negatif değer oluşması engelleniyor.
                loss_i = criterion(context_representations, target_representations[i])
                loss_accum += loss_i
                
            loss = loss_accum / len(target_representations)
            
            # Geriye yayılım
            loss.backward()
            
            # Loss toplamına ekleme (skaler olarak)
            total_loss += loss.item()
            
            # Parametre güncellemesi
            context_opt.step()
            predictor_opt.step()
            context_opt.zero_grad()
            predictor_opt.zero_grad()
            
            # CUDA kullanıyorsak senkronizasyon
            if device.type == "cuda":
                torch.cuda.synchronize()
        
        # Öğrenme oranı güncellemesi
        scheduler.step()
        
        avg_loss = total_loss / len(dataloader)
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1} | Loss: {avg_loss:.6f}")
    
    print("Finished Training")
    return train_losses

In [15]:
"""def train_iJEPA(context_model, target_model, predictor_model, num_epochs = 10):
    context_model  = context_model.to(device)
    target_model = target_model.to(device)
    predictor_model = predictor_model.to(device)
    dataloader = torch.utils.data.DataLoader(animal_dataset, batch_size =32 , shuffle = True, num_workers=4,drop_last=True)
    
    context_opt= torch.optim.AdamW(context_model.parameters(), lr=1e-4)
    predictor_opt = torch.optim.AdamW(predictor_model.parameters(), lr=1e-4)
    scheduler = StepLR(predictor_opt, step_size=3, gamma=0.1)

    for param in target_model.parameters():
        param.requires_grad = False
    train_losses = []
    context_model.train()
    predictor_model.train()
    train_losses = []
    for epochs in range(num_epochs):
        
        total_loss=0
        for idx,(x,_) in enumerate(tqdm(dataloader)):
            with torch.autograd.set_detect_anomaly(True):
                x = x.to(device)
                
                context, target, target_indices = context_model(x)
                
                target_representations = target_model(x,target, target_indices)
                loss_accum=0
                for i in range(len(target_representations)):
                    context_representations = predictor_model(context.detach(), target_indices[i])
                
               
                    loss_çakma = criterion(context_representations,target_representations[i])
                    loss_accum+=loss_çakma
                loss=loss_accum/len(target_representations)
                loss.backward()
                
                total_loss += loss
                context_opt.step()
                predictor_opt.step()
                context_opt.zero_grad()
                predictor_opt.zero_grad()
                if device == "cuda":
                    torch.cuda.synchronize()
        scheduler.step()
        avg_loss = total_loss / len(dataloader)
        train_losses.append(avg_loss)
       
        
        print(f"Epoch {epochs} | Loss: {avg_loss.item():.6f}")



    print('finished Training')
    return train_losses"""

'def train_iJEPA(context_model, target_model, predictor_model, num_epochs = 10):\n    context_model  = context_model.to(device)\n    target_model = target_model.to(device)\n    predictor_model = predictor_model.to(device)\n    dataloader = torch.utils.data.DataLoader(animal_dataset, batch_size =32 , shuffle = True, num_workers=4,drop_last=True)\n    \n    context_opt= torch.optim.AdamW(context_model.parameters(), lr=1e-4)\n    predictor_opt = torch.optim.AdamW(predictor_model.parameters(), lr=1e-4)\n    scheduler = StepLR(predictor_opt, step_size=3, gamma=0.1)\n\n    for param in target_model.parameters():\n        param.requires_grad = False\n    train_losses = []\n    context_model.train()\n    predictor_model.train()\n    train_losses = []\n    for epochs in range(num_epochs):\n        \n        total_loss=0\n        for idx,(x,_) in enumerate(tqdm(dataloader)):\n            with torch.autograd.set_detect_anomaly(True):\n                x = x.to(device)\n                \n          

In [18]:
IMG_SIZE = 128
transform = transforms.Compose([
    # Randomly flip images horizontally
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    # Convert images to PyTorch tensors and scale to [0, 1]
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])



animal_dataset = datasets.ImageFolder(root=path,  # Specify the root directory of the dataset
                               transform=transform)  # Apply the defined transformations to the dataset



In [21]:
trained=train_iJEPA(context_enocder,target_encoder,predictor,animal_dataset)

Epoch 1/10:   0%|          | 0/168 [00:00<?, ?it/s]

Epoch 1/10: 100%|██████████| 168/168 [01:13<00:00,  2.29it/s]


Epoch 1 | Loss: 0.087288


Epoch 2/10:  15%|█▍        | 25/168 [00:20<01:59,  1.20it/s]


KeyboardInterrupt: 