In [57]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torchvision.transforms as T



class CustomTIFDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []

        # Traverse the directory and collect image paths and labels
        for label in os.listdir(root_dir):
            label_dir = os.path.join(root_dir, label)
            if os.path.isdir(label_dir):
                for random_name in os.listdir(label_dir):
                    random_name_dir = os.path.join(label_dir, random_name)
                    if os.path.isdir(random_name_dir):
                        images = [os.path.join(random_name_dir, f) for f in os.listdir(random_name_dir) if f.endswith('.tif')]
                        if len(images) == 12:  # Ensure there are exactly 12 images
                            self.data.append((images, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_paths, label = self.data[idx]
        images = [Image.open(img_path) for img_path in image_paths]

        if self.transform:
            images = [self.transform(img) for img in images]

        # Stack images along the channel dimension if needed or process them as required
        images = torch.stack(images, dim=0)  # Assuming we want to stack along a new dimension

        return images, label

# Define any transforms if needed
transform = transforms.Compose([
    transforms.ToTensor(),
    # Add any other transformations here
])

# Custom collate function to resize images to a common size
def custom_collate_fn(batch):
    max_size = tuple(max(s) for s in zip(*[img.shape for img, _ in batch]))
    batch = [(T.Resize(max_size[1:])(img), target) for img, target in batch]
    images = torch.stack([img for img, _ in batch])
    targets = torch.stack([target for _, target in batch])
    return images, targets

# Create dataset and dataloader
dataset = CustomTIFDataset(root_dir=r'C:\Users\ADMIN\OneDrive\Desktop\VuonAI\AI FOR AGRICULTURE\Dataset\ICPR02\data demo', transform=transform)
dataloader = DataLoader(dataset, collate_fn=custom_collate_fn, batch_size=4, shuffle=True, num_workers=2)

<__main__.CustomTIFDataset object at 0x000001E453E5C5C0>


In [45]:
import torch
import torchvision.transforms as T

def regular_augmentation():
    return T.Compose([
        T.RandomRotation(degrees=45),
        T.RandomResizedCrop(size=120, scale=(0.8, 1.2)),
        T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
        T.Lambda(lambda x: x + 0.05 * torch.randn_like(x))  # Add random noise
    ])

In [46]:
class LocalGlobalAugmentation:
    def __init__(self, global_size=120, local_size=36):
        self.global_size = global_size
        self.local_size = local_size

    def __call__(self, x):
        # Assuming x is a tensor of shape (C, H, W)
        global_view = T.Resize(self.global_size)(x)
        
        # Create local views
        local_views = []
        for _ in range(4):  # Generating 4 local views
            i = torch.randint(0, x.size(1) - self.local_size, (1,))
            j = torch.randint(0, x.size(2) - self.local_size, (1,))
            local_view = T.Resize(self.local_size)(x[:, i:i+self.local_size, j:j+self.local_size])
            local_views.append(local_view)
        
        return global_view, local_views


In [47]:
class SpectralAwareAugmentation:
    def __init__(self, drop_rate=0.4):
        self.drop_rate = drop_rate

    def __call__(self, x):
        # Assuming x is a tensor of shape (C, H, W)
        num_channels = x.size(0)
        drop_mask = torch.rand(num_channels) > self.drop_rate
        return x * drop_mask.unsqueeze(1).unsqueeze(2)


In [48]:
class ObjSSLAugmentation:
    def __init__(self):
        self.regular_aug = regular_augmentation()
        self.lag_aug = LocalGlobalAugmentation()
        self.spectral_aug = SpectralAwareAugmentation()

    def __call__(self, x):
        # Apply regular augmentation
        x_aug = self.regular_aug(x)

        # Apply local and global augmentation
        global_view, local_views = self.lag_aug(x_aug)

        # Apply spectral-aware augmentation to the local views
        spectral_views = [self.spectral_aug(view) for view in local_views]

        return global_view, spectral_views


In [49]:
class ObjSSLModel(nn.Module):
    def __init__(self, student_network, teacher_network):
        super(ObjSSLModel, self).__init__()
        self.student_network = student_network
        self.teacher_network = teacher_network
        self.augment = ObjSSLAugmentation()

    def forward(self, x):
        global_view, local_views = self.augment(x)

        # Teacher network processes the global view
        with torch.no_grad():
            teacher_output = self.teacher_network(global_view)

        # Student network processes the local views
        student_outputs = [self.student_network(view) for view in local_views]

        return teacher_output, student_outputs

In [50]:
import torch.optim as optim
import torch.nn.functional as F
from torchvision.models import resnet18

# Define student and teacher networks
student_network = resnet18(pretrained=False)
teacher_network = resnet18(pretrained=False)

model = ObjSSLModel(student_network, teacher_network)
optimizer = optim.Adam(model.parameters(), lr=0.001)

def contrastive_loss(teacher_output, student_outputs):
    # Define a contrastive loss function
    # This is a placeholder, actual implementation may vary
    loss = 0
    for student_output in student_outputs:
        loss += F.mse_loss(student_output, teacher_output)
    return loss

num_epochs = 10


# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        x = batch['image']

        teacher_output, student_outputs = model(x)

        loss = contrastive_loss(teacher_output, student_outputs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


RuntimeError: DataLoader worker (pid(s) 4472, 15244) exited unexpectedly