In [3]:
!pip install torch torchvision tqdm h5py




In [4]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm  # For progress bar
import time
import os
from PIL import Image
import h5py

In [5]:
# Custom Dataset for PlantVillage with added error handling for debugging
class PlantVillageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.skipped_files = 0  # Count how many files are skipped

        for subdir, _, files in os.walk(root_dir):
            print(f"Directory: {subdir}, Number of files: {len(files)}")
            print("First 5 files:", files[:5])  # Print first 5 files in each directory
            
            for file in files:
                file_path = os.path.join(subdir, file)
                try:
                    # Test if the file can be opened as an image (ignore file extension)
                    with Image.open(file_path) as img:
                        self.image_paths.append(file_path)
                except Exception as e:
                    # Log the error if the file cannot be opened as an image
                    print(f"Error opening file {file_path}: {e}")
                    self.skipped_files += 1

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

In [6]:
# Data augmentation for self-supervised learning
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [7]:
# Load PlantVillage dataset
root_dir = '/kaggle/input/plantvillage-dataset/PlantVillage_dataset_2' # Update with actual path
dataset = PlantVillageDataset(root_dir=root_dir, transform=train_transform)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)

Directory: /kaggle/input/plantvillage-dataset/PlantVillage_dataset_2, Number of files: 43460
First 5 files: ['64e034cb-54bb-4434-b49a-42efe260c966___CREC_HLB 7803.JPG', '0ade19e4-c48b-4e58-bddf-153d44d48c3e___Com.G_SpM_FL 9449.JPG', '919505a0-7e06-4a89-8faf-f308bd644a6a___RS_NLB 4141.JPG', 'f218fcff-fc71-4db2-9ed8-52d0871d5a67___CREC_HLB 7344.JPG', 'd08a1c48-3360-40e3-9c3d-e47c2812bed2___UF.GRC_YLCV_Lab 01920.JPG']




In [8]:
# Print dataset size
print(f"Total number of images in the dataset: {len(dataset)}")
# Calculate expected number of batches
expected_batches = len(dataset) // 64
print(f"Expected number of batches: {expected_batches}")

Total number of images in the dataset: 43460
Expected number of batches: 679


In [9]:
# Load MobileNet with ImageNet weights, removing the classifier for self-supervised learning
mobilenet = models.mobilenet_v2(pretrained=True)
mobilenet.classifier = nn.Identity()


Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 111MB/s] 


In [10]:
# Define MoCo-style Model with MobileNet as the encoder and dropout
class MoCoModelWithDropout(nn.Module):
    def __init__(self, base_encoder, dropout_prob=0.5):
        super(MoCoModelWithDropout, self).__init__()
        self.encoder_q = nn.Sequential(
            base_encoder,  # Query encoder
            nn.Dropout(p=dropout_prob)  # Add dropout layer
        )
        self.encoder_k = nn.Sequential(
            base_encoder,  # Key encoder
            nn.Dropout(p=dropout_prob)  # Add dropout layer
        )

    def forward(self, x_q, x_k):
        q = self.encoder_q(x_q)  # Encode the query with dropout
        k = self.encoder_k(x_k)  # Encode the key with dropout
        return q, k



In [11]:
# Initialize MoCo with MobileNet as the encoder and dropout
model_with_dropout = MoCoModelWithDropout(mobilenet, dropout_prob=0.5).cuda()

In [12]:
# Define the AdamW optimizer
optimizer = torch.optim.AdamW(model_with_dropout.parameters(), lr=0.01, weight_decay=0.01)


In [13]:
# Define the InfoNCE (contrastive) loss function with accuracy calculation
def contrastive_loss_with_accuracy(q, k, temperature=0.07):
    # Normalize the queries and keys
    q = nn.functional.normalize(q, dim=1)
    k = nn.functional.normalize(k, dim=1)

    # Compute positive logits (similarity between q and k)
    pos_logits = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)

    # Compute negative logits (similarity between q and all keys)
    neg_logits = torch.mm(q, k.T)

    # Combine the logits and apply softmax
    logits = torch.cat([pos_logits, neg_logits], dim=1)
    logits /= temperature
    labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

    # Compute InfoNCE loss
    loss = nn.CrossEntropyLoss()(logits, labels)

    # Calculate accuracy: Check if the highest logit is the positive one
    _, preds = torch.max(logits, dim=1)
    accuracy = (preds == labels).float().mean().item()

    return loss, accuracy

In [14]:
# Training loop for 100 epochs with accurate loss and accuracy calculation
num_epochs = 100
for epoch in range(num_epochs):
    model_with_dropout.train()
    total_loss = 0
    total_accuracy = 0
    start_time = time.time()
    
    # Progress bar for the current epoch
    progress_bar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    
    for batch in progress_bar:
        x_q = batch.cuda()
        x_k = batch.cuda()  # Ideally, different augmentations should be applied here for MoCo

        # Forward pass through MoCo model
        q, k = model_with_dropout(x_q, x_k)

        # Calculate contrastive loss and accuracy
        loss, accuracy = contrastive_loss_with_accuracy(q, k)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy
        progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy * 100)
    
    avg_loss = total_loss / len(train_loader)
    avg_accuracy = total_accuracy / len(train_loader)
    elapsed_time = time.time() - start_time
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy*100:.2f}%, Time: {elapsed_time:.2f}s")

Epoch [1/100]: 100%|██████████| 680/680 [05:22<00:00,  2.11it/s, accuracy=75, loss=0.72]   


Epoch [1/100], Avg Loss: 1.1294, Accuracy: 71.18%, Time: 322.01s


Epoch [2/100]: 100%|██████████| 680/680 [05:19<00:00,  2.13it/s, accuracy=50, loss=0.702]  


Epoch [2/100], Avg Loss: 1.0417, Accuracy: 72.88%, Time: 319.44s


Epoch [3/100]: 100%|██████████| 680/680 [05:18<00:00,  2.13it/s, accuracy=100, loss=0.697] 


Epoch [3/100], Avg Loss: 0.9643, Accuracy: 75.58%, Time: 318.81s


Epoch [4/100]: 100%|██████████| 680/680 [05:19<00:00,  2.13it/s, accuracy=0, loss=0.699]   


Epoch [4/100], Avg Loss: 0.9447, Accuracy: 75.80%, Time: 319.04s


Epoch [5/100]: 100%|██████████| 680/680 [05:18<00:00,  2.13it/s, accuracy=100, loss=0.703] 


Epoch [5/100], Avg Loss: 0.9169, Accuracy: 76.14%, Time: 318.67s


Epoch [6/100]: 100%|██████████| 680/680 [05:17<00:00,  2.14it/s, accuracy=50, loss=0.697]  


Epoch [6/100], Avg Loss: 0.8985, Accuracy: 77.01%, Time: 317.87s


Epoch [7/100]: 100%|██████████| 680/680 [05:17<00:00,  2.14it/s, accuracy=100, loss=0.695] 


Epoch [7/100], Avg Loss: 0.8961, Accuracy: 76.99%, Time: 317.65s


Epoch [8/100]: 100%|██████████| 680/680 [05:17<00:00,  2.14it/s, accuracy=100, loss=0.699] 


Epoch [8/100], Avg Loss: 0.8748, Accuracy: 77.59%, Time: 317.59s


Epoch [9/100]: 100%|██████████| 680/680 [05:17<00:00,  2.14it/s, accuracy=75, loss=0.695]  


Epoch [9/100], Avg Loss: 0.8766, Accuracy: 77.31%, Time: 317.15s


Epoch [10/100]: 100%|██████████| 680/680 [05:17<00:00,  2.14it/s, accuracy=75, loss=0.699]  


Epoch [10/100], Avg Loss: 0.8758, Accuracy: 77.37%, Time: 317.07s


Epoch [11/100]: 100%|██████████| 680/680 [05:16<00:00,  2.15it/s, accuracy=100, loss=0.696] 


Epoch [11/100], Avg Loss: 0.8731, Accuracy: 77.32%, Time: 316.66s


Epoch [12/100]: 100%|██████████| 680/680 [05:16<00:00,  2.15it/s, accuracy=100, loss=0.698] 


Epoch [12/100], Avg Loss: 0.8774, Accuracy: 77.37%, Time: 316.79s


Epoch [13/100]: 100%|██████████| 680/680 [05:16<00:00,  2.15it/s, accuracy=50, loss=0.696]  


Epoch [13/100], Avg Loss: 0.8738, Accuracy: 77.45%, Time: 316.28s


Epoch [14/100]: 100%|██████████| 680/680 [05:15<00:00,  2.15it/s, accuracy=75, loss=0.707]  


Epoch [14/100], Avg Loss: 0.8694, Accuracy: 77.76%, Time: 315.87s


Epoch [15/100]: 100%|██████████| 680/680 [05:15<00:00,  2.15it/s, accuracy=75, loss=0.696]  


Epoch [15/100], Avg Loss: 0.8711, Accuracy: 77.45%, Time: 315.82s


Epoch [16/100]: 100%|██████████| 680/680 [05:15<00:00,  2.15it/s, accuracy=100, loss=0.7]   


Epoch [16/100], Avg Loss: 0.8664, Accuracy: 77.93%, Time: 315.59s


Epoch [17/100]: 100%|██████████| 680/680 [05:15<00:00,  2.15it/s, accuracy=50, loss=0.695]  


Epoch [17/100], Avg Loss: 0.8667, Accuracy: 77.77%, Time: 315.73s


Epoch [18/100]: 100%|██████████| 680/680 [05:15<00:00,  2.15it/s, accuracy=75, loss=0.697]  


Epoch [18/100], Avg Loss: 0.8650, Accuracy: 77.77%, Time: 315.97s


Epoch [19/100]: 100%|██████████| 680/680 [05:15<00:00,  2.16it/s, accuracy=75, loss=0.695]  


Epoch [19/100], Avg Loss: 0.8619, Accuracy: 77.92%, Time: 315.29s


Epoch [20/100]: 100%|██████████| 680/680 [05:15<00:00,  2.16it/s, accuracy=75, loss=0.696]  


Epoch [20/100], Avg Loss: 0.8501, Accuracy: 78.11%, Time: 315.09s


Epoch [21/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=50, loss=0.695]  


Epoch [21/100], Avg Loss: 0.8640, Accuracy: 77.60%, Time: 314.95s


Epoch [22/100]: 100%|██████████| 680/680 [05:15<00:00,  2.16it/s, accuracy=75, loss=0.697]  


Epoch [22/100], Avg Loss: 0.8524, Accuracy: 78.01%, Time: 315.05s


Epoch [23/100]: 100%|██████████| 680/680 [05:15<00:00,  2.16it/s, accuracy=100, loss=0.695] 


Epoch [23/100], Avg Loss: 0.8520, Accuracy: 78.37%, Time: 315.41s


Epoch [24/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=100, loss=0.694] 


Epoch [24/100], Avg Loss: 0.8616, Accuracy: 78.15%, Time: 314.70s


Epoch [25/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=75, loss=0.695]  


Epoch [25/100], Avg Loss: 0.8548, Accuracy: 77.78%, Time: 314.91s


Epoch [26/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=25, loss=0.695]  


Epoch [26/100], Avg Loss: 0.8461, Accuracy: 77.75%, Time: 314.63s


Epoch [27/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=75, loss=0.694]  


Epoch [27/100], Avg Loss: 0.8536, Accuracy: 77.75%, Time: 314.55s


Epoch [28/100]: 100%|██████████| 680/680 [05:15<00:00,  2.16it/s, accuracy=75, loss=0.696]  


Epoch [28/100], Avg Loss: 0.8562, Accuracy: 77.78%, Time: 315.08s


Epoch [29/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=75, loss=0.708]  


Epoch [29/100], Avg Loss: 0.8561, Accuracy: 77.98%, Time: 314.46s


Epoch [30/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=100, loss=0.696] 


Epoch [30/100], Avg Loss: 0.8556, Accuracy: 78.19%, Time: 314.60s


Epoch [31/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=100, loss=0.696] 


Epoch [31/100], Avg Loss: 0.8591, Accuracy: 77.57%, Time: 314.43s


Epoch [32/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=75, loss=0.703]  


Epoch [32/100], Avg Loss: 0.8572, Accuracy: 77.99%, Time: 314.33s


Epoch [33/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=50, loss=0.695]  


Epoch [33/100], Avg Loss: 0.8579, Accuracy: 77.96%, Time: 314.16s


Epoch [34/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=50, loss=0.696]  


Epoch [34/100], Avg Loss: 0.8583, Accuracy: 77.72%, Time: 314.22s


Epoch [35/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=75, loss=0.695]  


Epoch [35/100], Avg Loss: 0.8712, Accuracy: 77.53%, Time: 314.14s


Epoch [36/100]: 100%|██████████| 680/680 [05:14<00:00,  2.17it/s, accuracy=25, loss=0.697]  


Epoch [36/100], Avg Loss: 0.8619, Accuracy: 77.38%, Time: 314.03s


Epoch [37/100]: 100%|██████████| 680/680 [05:14<00:00,  2.16it/s, accuracy=75, loss=0.697]  


Epoch [37/100], Avg Loss: 0.8652, Accuracy: 77.73%, Time: 314.49s


Epoch [38/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=50, loss=0.698]  


Epoch [38/100], Avg Loss: 0.8695, Accuracy: 77.48%, Time: 313.75s


Epoch [39/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=50, loss=0.705]  


Epoch [39/100], Avg Loss: 0.8643, Accuracy: 77.75%, Time: 313.86s


Epoch [40/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.695]  


Epoch [40/100], Avg Loss: 0.8666, Accuracy: 77.64%, Time: 313.76s


Epoch [41/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=25, loss=0.694]  


Epoch [41/100], Avg Loss: 0.8691, Accuracy: 77.50%, Time: 313.76s


Epoch [42/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=100, loss=0.695] 


Epoch [42/100], Avg Loss: 0.8699, Accuracy: 77.65%, Time: 313.51s


Epoch [43/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=100, loss=0.694] 


Epoch [43/100], Avg Loss: 0.8699, Accuracy: 77.96%, Time: 313.70s


Epoch [44/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=50, loss=0.696]  


Epoch [44/100], Avg Loss: 0.8797, Accuracy: 77.33%, Time: 313.37s


Epoch [45/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.696]  


Epoch [45/100], Avg Loss: 0.8789, Accuracy: 77.19%, Time: 313.70s


Epoch [46/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=50, loss=0.696]  


Epoch [46/100], Avg Loss: 0.8772, Accuracy: 77.44%, Time: 313.54s


Epoch [47/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.698]  


Epoch [47/100], Avg Loss: 0.8761, Accuracy: 77.35%, Time: 313.24s


Epoch [48/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=50, loss=0.694]  


Epoch [48/100], Avg Loss: 0.9037, Accuracy: 76.60%, Time: 313.23s


Epoch [49/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.695]  


Epoch [49/100], Avg Loss: 0.8830, Accuracy: 76.97%, Time: 313.42s


Epoch [50/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=50, loss=0.696]  


Epoch [50/100], Avg Loss: 0.8783, Accuracy: 77.18%, Time: 313.03s


Epoch [51/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=50, loss=0.694]  


Epoch [51/100], Avg Loss: 0.8606, Accuracy: 77.33%, Time: 313.26s


Epoch [52/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=100, loss=0.697] 


Epoch [52/100], Avg Loss: 0.8747, Accuracy: 77.57%, Time: 313.38s


Epoch [53/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.694]  


Epoch [53/100], Avg Loss: 0.8584, Accuracy: 77.95%, Time: 313.58s


Epoch [54/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=50, loss=0.699]  


Epoch [54/100], Avg Loss: 0.8583, Accuracy: 77.90%, Time: 313.50s


Epoch [55/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=50, loss=0.695]  


Epoch [55/100], Avg Loss: 0.8602, Accuracy: 77.34%, Time: 313.30s


Epoch [56/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.695]  


Epoch [56/100], Avg Loss: 0.8556, Accuracy: 77.76%, Time: 313.56s


Epoch [57/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=100, loss=0.697] 


Epoch [57/100], Avg Loss: 0.8642, Accuracy: 77.45%, Time: 313.87s


Epoch [58/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.704]  


Epoch [58/100], Avg Loss: 0.8871, Accuracy: 77.16%, Time: 313.27s


Epoch [59/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=50, loss=0.696]  


Epoch [59/100], Avg Loss: 0.8727, Accuracy: 77.25%, Time: 312.80s


Epoch [60/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.699]  


Epoch [60/100], Avg Loss: 0.8566, Accuracy: 77.72%, Time: 313.12s


Epoch [61/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.696]  


Epoch [61/100], Avg Loss: 0.8589, Accuracy: 77.76%, Time: 313.28s


Epoch [62/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=50, loss=0.695]  


Epoch [62/100], Avg Loss: 0.8777, Accuracy: 76.90%, Time: 312.93s


Epoch [63/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.696] 


Epoch [63/100], Avg Loss: 0.8996, Accuracy: 76.47%, Time: 312.78s


Epoch [64/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=50, loss=0.695]  


Epoch [64/100], Avg Loss: 0.8657, Accuracy: 77.22%, Time: 312.83s


Epoch [65/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=50, loss=0.694]  


Epoch [65/100], Avg Loss: 0.8690, Accuracy: 77.02%, Time: 312.74s


Epoch [66/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.695] 


Epoch [66/100], Avg Loss: 0.8677, Accuracy: 77.06%, Time: 312.85s


Epoch [67/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=100, loss=0.695] 


Epoch [67/100], Avg Loss: 0.8674, Accuracy: 77.35%, Time: 312.42s


Epoch [68/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=100, loss=0.694] 


Epoch [68/100], Avg Loss: 0.8589, Accuracy: 77.59%, Time: 313.62s


Epoch [69/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.699] 


Epoch [69/100], Avg Loss: 0.8613, Accuracy: 77.57%, Time: 312.79s


Epoch [70/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=100, loss=0.698] 


Epoch [70/100], Avg Loss: 0.8616, Accuracy: 77.59%, Time: 313.07s


Epoch [71/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.696] 


Epoch [71/100], Avg Loss: 0.8655, Accuracy: 78.07%, Time: 312.78s


Epoch [72/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.695]  


Epoch [72/100], Avg Loss: 0.8649, Accuracy: 77.89%, Time: 313.02s


Epoch [73/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=75, loss=0.695]  


Epoch [73/100], Avg Loss: 0.8595, Accuracy: 77.96%, Time: 312.89s


Epoch [74/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=75, loss=0.695]  


Epoch [74/100], Avg Loss: 0.8618, Accuracy: 77.56%, Time: 312.79s


Epoch [75/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=75, loss=0.701]  


Epoch [75/100], Avg Loss: 0.8594, Accuracy: 77.65%, Time: 312.90s


Epoch [76/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=75, loss=0.695]  


Epoch [76/100], Avg Loss: 0.8611, Accuracy: 77.81%, Time: 312.99s


Epoch [77/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.694] 


Epoch [77/100], Avg Loss: 0.8727, Accuracy: 77.30%, Time: 312.88s


Epoch [78/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.695] 


Epoch [78/100], Avg Loss: 0.8521, Accuracy: 77.85%, Time: 312.93s


Epoch [79/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.696]  


Epoch [79/100], Avg Loss: 0.8526, Accuracy: 77.96%, Time: 313.28s


Epoch [80/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.696] 


Epoch [80/100], Avg Loss: 0.8607, Accuracy: 77.44%, Time: 312.83s


Epoch [81/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=75, loss=0.696]  


Epoch [81/100], Avg Loss: 0.8672, Accuracy: 77.26%, Time: 312.45s


Epoch [82/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=50, loss=0.695]  


Epoch [82/100], Avg Loss: 0.8790, Accuracy: 77.34%, Time: 312.59s


Epoch [83/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.7]   


Epoch [83/100], Avg Loss: 0.8750, Accuracy: 77.22%, Time: 312.90s


Epoch [84/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=100, loss=0.695] 


Epoch [84/100], Avg Loss: 0.8657, Accuracy: 77.76%, Time: 312.61s


Epoch [85/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=100, loss=0.699] 


Epoch [85/100], Avg Loss: 0.8617, Accuracy: 78.07%, Time: 312.58s


Epoch [86/100]: 100%|██████████| 680/680 [05:13<00:00,  2.17it/s, accuracy=75, loss=0.695]  


Epoch [86/100], Avg Loss: 0.8630, Accuracy: 77.70%, Time: 313.05s


Epoch [87/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=75, loss=0.694]  


Epoch [87/100], Avg Loss: 1.0582, Accuracy: 73.02%, Time: 312.28s


Epoch [88/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=75, loss=0.728]  


Epoch [88/100], Avg Loss: 0.9406, Accuracy: 75.60%, Time: 312.71s


Epoch [89/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=75, loss=0.699]  


Epoch [89/100], Avg Loss: 0.9289, Accuracy: 75.68%, Time: 312.25s


Epoch [90/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.702] 


Epoch [90/100], Avg Loss: 0.9517, Accuracy: 75.36%, Time: 312.79s


Epoch [91/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.696] 


Epoch [91/100], Avg Loss: 0.9388, Accuracy: 75.34%, Time: 312.72s


Epoch [92/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=75, loss=0.695]  


Epoch [92/100], Avg Loss: 0.9375, Accuracy: 75.42%, Time: 312.48s


Epoch [93/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.696] 


Epoch [93/100], Avg Loss: 0.9468, Accuracy: 75.38%, Time: 312.79s


Epoch [94/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=75, loss=0.694]  


Epoch [94/100], Avg Loss: 0.9359, Accuracy: 75.28%, Time: 312.82s


Epoch [95/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.694] 


Epoch [95/100], Avg Loss: 0.9209, Accuracy: 75.90%, Time: 312.75s


Epoch [96/100]: 100%|██████████| 680/680 [05:12<00:00,  2.17it/s, accuracy=100, loss=0.695] 


Epoch [96/100], Avg Loss: 0.9047, Accuracy: 77.07%, Time: 312.68s


Epoch [97/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=100, loss=0.72]  


Epoch [97/100], Avg Loss: 0.9094, Accuracy: 76.40%, Time: 312.23s


Epoch [98/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=75, loss=0.698]  


Epoch [98/100], Avg Loss: 0.9042, Accuracy: 76.90%, Time: 312.54s


Epoch [99/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=25, loss=0.702]  


Epoch [99/100], Avg Loss: 0.9159, Accuracy: 76.07%, Time: 312.44s


Epoch [100/100]: 100%|██████████| 680/680 [05:12<00:00,  2.18it/s, accuracy=75, loss=0.695]  

Epoch [100/100], Avg Loss: 0.9015, Accuracy: 76.91%, Time: 312.36s





In [21]:
# Saving the model in different formats after all epochs
def save_model(model):
    torch.save(model.state_dict(), "MOCO_model_weights_final.pth")
    torch.save(model, "MOCO_full_model_final.pth")
    with h5py.File("MOCO_model_weights_final.h5", 'w') as f:
        for name, param in model.named_parameters():
            f.create_dataset(name, data=param.cpu().detach().numpy())

# Save model after training completes
save_model(model_with_dropout)

print("model saved")

model saved
