In [1]:
import os
import random
import shutil
import argparse
import pandas as pd
import gc

def split_dataset(image_dir, output_dir, train_ratio, val_ratio, test_ratio):
    
    images = [img for img in os.listdir(image_dir)]
        
    

    # Adjust the extension
    total_images = len(images)
    random.shuffle(images)

    train_size = int(total_images * train_ratio)
    val_size = int(total_images * val_ratio)

    # Split the dataset
    train_images = images[:train_size]
    val_images = images[train_size:train_size + val_size]
    test_images = images[train_size + val_size:]

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with open(os.path.join(output_dir, 'train.txt'), 'w') as f:
      for item in train_images:
        file_name_without_ext = item.lower().rsplit('.jpg', 1)[0]
        f.write("%s\n" % file_name_without_ext)
        #f.write("%s\n" % os.path.splitext(item)[0]) 

    with open(os.path.join(output_dir, 'val.txt'), 'w') as f:
      for item in val_images:
        file_name_without_ext = item.lower().rsplit('.jpg', 1)[0]
        f.write("%s\n" % file_name_without_ext)
        #f.write("%s\n" % os.path.splitext(item)[0]) 
        
    with open(os.path.join(output_dir, 'test.txt'), 'w') as f:
      for item in test_images:
        file_name_without_ext = item.lower().rsplit('.jpg', 1)[0]
        f.write("%s\n" % file_name_without_ext)
        #f.write("%s\n" % os.path.splitext(item)[0])  

    return train_images, val_images, test_images


In [2]:
split_dataset('/teamspace/studios/this_studio/discriminator_false_label/dataset/samples', '/teamspace/studios/this_studio/discriminator_false_label/dataset/splits', 0.75, 0.2,0.05)

(['eastafrica_oleacape_2020.06.02.13.04.00_529ffc53-5bb9-4024-8e0e-fb0cdb1cbd21_img_20200529_160846_1078423866.jpg',
  'haiti_acacauri_2021.05.31.10.44.04_18.28520064242184_-73.564688321203_f1b23d67-cfd2-4741-8975-715dc0f931f2_img_20210527_075726_1008487320000484652.jpg',
  'eastafrica_afzeliaa_2023.03.30.15.58.48_2.2114935166666667_31.47824771666667_077a9822-e16b-4556-ab19-6be7d90aba72_img_20230328_132513_8240023535835035087.jpg',
  'eastafrica_avicenni_2021.07.15.13.30.45_-5.125406691999997_39.11243287999997_d6587876-2bab-4ef3-b76d-55a4b682ed70_img_20210715_114220_323165884590609859.jpg',
  'eastafrica_avicenni_2021.07.15.13.30.19_-5.126656833999998_39.11121906200001_748228c3-63bd-46e2-b0b2-ed9709c4e34c_img_20210715_113417_3011454690660837856.jpg',
  'india_psidguaj (guava)2021.01.30.19.30.36_25.230063227936625_79.32936024852097_d68cf102-60d1-437e-ba5b-6570b1b3daf0_img_20210127_110234_1909062434403654501.jpg',
  'eastafrica_cordafri_2022.08.25.13.24.26_-3.3070423333333316_37.29833933

In [2]:
import torch.utils.data as data
class LeafDataset(data.Dataset):
    def __init__(self, root, image_set='train', img_transform=None, mask_transform=None):
        self.root = os.path.expanduser(root)
        self.img_transform = img_transform
        self.mask_transform = mask_transform
        self.image_set = image_set
        
        # Adjust paths
        image_dir = os.path.join(self.root, 'samples')
        mask_dir = os.path.join(self.root, 'binary_masks')
        label_dir = os.path.join(self.root, 'labels')
        split_fpath = os.path.join(self.root, 'splits', f'{self.image_set}.txt')        
        
        with open(split_fpath, 'r') as f:
            file_names = [x.strip() for x in f.readlines()]        

        self.images = [os.path.join(image_dir, fname + '.jpg') for fname in file_names]
        self.masks = [os.path.join(mask_dir, fname + '_binarymask.jpg') for fname in file_names]
        self.labels = [os.path.join(label_dir, fname + '_label.txt') for fname in file_names]
        
    def __getitem__(self, index):
        img_path = self.images[index]
        mask_path = self.masks[index]
        label_path = self.labels[index]

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path)

        mask_array = np.array(mask)
        mask_array = (mask_array > 128).astype(np.uint8) # Binarize to 0s and 1s

        mask_array = mask_array * 255
        mask = Image.fromarray(mask_array.astype(np.uint8))
        
        if self.img_transform is not None:
          img = self.img_transform(img)
        
        if self.mask_transform is not None:
          mask = self.mask_transform(mask)

        #mask = torch.squeeze(mask, 0)
        if mask.shape[0] != 3:
            mask = mask.repeat(3, 1, 1)

        with open(label_path, 'r') as f:
            label = f.readline()
            label = int(label.strip()) 
            label = torch.tensor(1) if label == 1 else torch.tensor(0)

        return img, mask, label

    
    def __len__(self):
        return len(self.images)

In [3]:
from torchvision import transforms
train_img_transform = transforms.Compose([
          #RandomCropAndPad(512),
          transforms.Resize((512, 512)),
          #transforms.RandomResizedCrop(size=(256, 256)),
          #transforms.RandomHorizontalFlip(),
          #transforms.RandomRotation(degrees=(0, 360)),
          #transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
          transforms.ToTensor(),
          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

train_mask_transform = transforms.Compose([
            #RandomCropAndPadMask(512),
            transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.NEAREST),
            #transforms.RandomResizedCrop(size=(256, 256), interpolation=transforms.InterpolationMode.NEAREST),
            #transforms.RandomHorizontalFlip(),
            #transforms.RandomRotation(degrees=(0, 360)),
            transforms.ToTensor(),
        ])

In [4]:
train_dst = LeafDataset(root= '/teamspace/studios/this_studio/discriminator_false_label/dataset/', image_set='train', img_transform=train_img_transform, mask_transform=train_mask_transform)

In [5]:
train_loader = data.DataLoader(
        train_dst, batch_size= 4, shuffle=True, num_workers=1,
        drop_last=True) 

In [6]:
print(len(train_dst))

51


In [7]:
from PIL import Image
import numpy as np
import torch
for imgs, masks, labels in train_loader:
    print(imgs.shape, masks.shape, labels.shape)
    break

torch.Size([4, 3, 512, 512]) torch.Size([4, 3, 512, 512]) torch.Size([4])


In [36]:
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
# Define your discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(6, 16, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(64 * 64 * 64, 256)
        self.fc2 = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, img, mask):
        if len(mask.shape) == 3:
            mask = mask.unsqueeze(1)  # Add channel dimension if missing

        x = torch.cat((img, mask), dim=1)  # Concatenate image and mask along channel dimension
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2(x), 0.2)
        x = F.leaky_relu(self.conv3(x), 0.2)
        x = x.view(x.size(0), -1)
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = self.fc2(x).squeeze(1)
        return self.sigmoid(x)

In [76]:
discriminator = Discriminator()
criterion = nn.BCELoss()
optimizer = optim.Adam(discriminator.parameters(), lr=0.25, betas=(0.5, 0.999))

def compute_loss(predictions, targets):
    # predictions: tensor of shape (batch_size, ...)
    # targets: tensor of shape (batch_size, ...), dtype Long

    # Convert targets to float
    targets = targets.float()

    # Compute binary cross entropy loss
    loss = F.binary_cross_entropy_with_logits(predictions, targets)

    return loss

In [10]:
print(os.getcwd())

/teamspace/studios/this_studio/discriminator_false_label


In [77]:
# Training loop
import torch
import numpy as np
num_epochs = 15
cur_itrs = 0
interval_loss = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = Discriminator().to(device)
for epoch in range(num_epochs):
    model.train()
    for imgs,masks,labels in train_loader:
        imgs, masks, labels = imgs.to(device), masks.to(device), labels.to(device)
        cur_itrs += 1

        # Forward pass
        outputs = model(imgs, masks)
        #print(labels)
        loss = compute_loss(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        np_loss = loss.detach().cpu().numpy()
        #print(np_loss)
        interval_loss += np_loss
        del imgs, masks, labels, outputs, loss
        gc.collect()
        torch.cuda.empty_cache() 
        if (cur_itrs) % 10 == 0:
            print('interval loss ' + str(interval_loss))
            interval_loss = 0.0
    model_path = 'trial_1.pth'
    optimizer_path = 'optimizer_trial_1.pth'
    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optimizer_path)
    print('saved.')

cuda
interval loss 8.00613647699356
saved.
interval loss 7.119887351989746
saved.
interval loss 7.245870769023895
saved.
interval loss 7.372581660747528
saved.
interval loss 7.628607630729675
interval loss 7.499457478523254
saved.
interval loss 7.371996104717255
saved.
interval loss 7.879944443702698
saved.
interval loss 7.7511221170425415
saved.
interval loss 7.499353766441345
saved.
interval loss 7.120481729507446
interval loss 7.628564238548279
saved.
interval loss 7.626720368862152
saved.
interval loss 7.499725341796875
saved.
interval loss 7.625049352645874
saved.
interval loss 7.6290686428546906
saved.
interval loss 7.2455264031887054
interval loss 7.373652338981628
saved.


In [78]:
val_dst = LeafDataset(root= '/teamspace/studios/this_studio/discriminator_false_label/dataset/', image_set='val', img_transform=train_img_transform, mask_transform=train_mask_transform)

In [33]:
test_dst = LeafDataset(root= '/teamspace/studios/this_studio/discriminator_false_label/dataset/', image_set='test', img_transform=train_img_transform, mask_transform=train_mask_transform)

In [71]:
print(len(val_dst))

13


In [79]:
val_loader = data.DataLoader(
        val_dst, batch_size= 4, shuffle=True, num_workers=1,
        drop_last=True) 

In [80]:
model = Discriminator().to(device)
model.load_state_dict(torch.load('/teamspace/studios/this_studio/discriminator_false_label/trial_1.pth'))
model.eval()  # Set the model to evaluation mode

Discriminator(
  (conv1): Conv2d(6, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (fc1): Linear(in_features=262144, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [81]:
    val_loss = 0
    all_outputs = []
    all_labels = []

In [82]:
    with torch.no_grad():
        for imgs, masks, labels in val_loader:
            imgs, masks, labels = imgs.to(device), masks.to(device), labels.to(device)
            outputs = model(imgs, masks)
            loss = compute_loss(outputs, labels)
            val_loss += loss.item()

            all_outputs.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Clear memory
            del imgs, masks, labels, outputs, loss
            torch.cuda.empty_cache()

    avg_val_loss = val_loss / len(val_loader)
    all_outputs = np.array(all_outputs)
    print(all_outputs)
    all_labels = np.array(all_labels)
    print(all_labels)

    # Compute additional metrics like accuracy, precision, recall, etc.
    predictions = (all_outputs > 0.5).astype(int)
    accuracy = np.mean(predictions == all_labels)
    
    print(f'Validation Loss: {avg_val_loss}')
    print(f'Validation Accuracy: {accuracy}')

[0.5107652  0.502795   0.5079358  0.5071701  0.5064978  0.5031692
 0.50639796 0.50610954 0.50940406 0.50652844 0.50366265 0.5112094 ]
[1 0 1 1 1 0 1 1 1 1 0 1]
Validation Loss: 0.5973167816797892
Validation Accuracy: 0.75


In [32]:
for imgs, masks, labels in val_loader:
    labels = labels.to(device)
    print(labels)

tensor([1, 0, 1, 0], device='cuda:0')
tensor([1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 1, 1], device='cuda:0')


In [36]:
def validate_model(model, val_loader, device):
    model.eval()
    val_loss = 0
    all_outputs = []
    all_labels = []

    with torch.no_grad():
        for imgs, masks, labels in val_loader:
            imgs, masks, labels = imgs.to(device), masks.to(device), labels.to(device)
            outputs = model(imgs, masks)
            loss = compute_loss(outputs, labels)
            val_loss += loss.item()

            all_outputs.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Clear memory
            del imgs, masks, labels, outputs, loss
            torch.cuda.empty_cache()

    avg_val_loss = val_loss / len(val_loader)
    all_outputs = np.array(all_outputs)
    all_labels = np.array(all_labels)

    # Compute additional metrics like accuracy, precision, recall, etc.
    predictions = (all_outputs > 0.5).astype(int)
    accuracy = np.mean(predictions == all_labels)
    
    print(f'Validation Loss: {avg_val_loss}')
    print(f'Validation Accuracy: {accuracy}')

    return avg_val_loss, accuracy

# Assuming the rest of your setup is complete...
# Validate the model
val_loss, val_accuracy = validate_model(model, val_loader, device)


Validation Loss: 0.6988524198532104
Validation Accuracy: 0.25
