In [15]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [21]:
import os
path = '/content/drive/MyDrive/Dataset'
print(os.listdir(path))

['brinjal _cercospora leaf spot', 'naval_healthy', 'Lotus Rotting tubers', 'Brinjal Tobacco mosaic virus', 'Lotus nutrient deficiency  and rotting tubers', 'naval_anthracnose', 'naval_leaf_galls']


In [17]:
!pip install python-dotenv

Collecting python-dotenv
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Downloading python_dotenv-1.1.0-py3-none-any.whl (20 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.1.0


In [22]:
from dotenv import load_dotenv
load_dotenv('/content/drive/MyDrive/SECRET.env')

True

In [23]:
import os, torch, shutil, numpy as np
from glob import glob
from PIL import Image
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision import transforms as T

class CustomDataset(Dataset):
  def __init__(self, root, transformations=None):
    self.root = root
    self.transformations = transformations
    self.im_paths = sorted(glob(f"{root}/*/*"))
    self.class_names, self.class_counts, count = {}, {}, 0
    for idx, im_path in enumerate(self.im_paths):
      classname = self.getClassName(im_path)
      if classname not in self.class_names:
        self.class_names[classname] = count
        self.class_counts[classname] = 1
        count += 1
      else:
        self.class_counts[classname] += 1

  def getClassName(self, path):
      return os.path.dirname(path).split('/')[-1]

  def __len__(self):
      return len(self.im_paths)
  def __getitem__(self, idx):
     im_path = self.im_paths[idx]
     im = Image.open(im_path).convert("RGB")
     gt = self.class_names[self.getClassName(im_path)]

     if self.transformations:
      im = self.transformations(im)

     return im, gt


def get_dls(root, transformations, bs, split=[0.7, 0.15, 0.15], ns=2):
  ds = CustomDataset(root=root, transformations=transformations)

  tot_len = len(ds)
  tr_len = int(tot_len*split[0])
  vl_len = int(tot_len*split[1])
  ts_len = tot_len - tr_len - vl_len

  tr_ds, vl_ds, ts_ds = random_split(dataset=ds, lengths=[tr_len, vl_len, ts_len])
  tr_dl = DataLoader(tr_ds, batch_size=bs, shuffle=True, num_workers=ns)
  vl_dl = DataLoader(vl_ds, batch_size=bs, shuffle=False, num_workers=ns)
  ts_dl = DataLoader(ts_ds, batch_size=bs, shuffle=False, num_workers=ns)

  return tr_dl, vl_dl, ts_dl, ds.class_counts, ds.class_names, ds.im_paths

root = path
mean, std, im_size = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], 224
tfs = T.Compose([
    T.Resize((im_size, im_size)),
    T.RandomHorizontalFlip(),
    # T.RandomVerticalFlip(),
    # T.RandomRotation(20),
    # T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    # T.RandomResizedCrop(im_size, scale=(0.8, 1.0)),
    T.ToTensor(),
    T.Normalize(mean, std)
])

tr_dl, vl_dl, ts_dl, class_counts, classes, im_paths = get_dls(root, tfs, 16)

print(len(tr_dl), len(vl_dl), len(ts_dl))
print(class_counts)
print(len(im_paths))

7 2 2
{'Brinjal Tobacco mosaic virus': 9, 'Lotus Rotting tubers': 46, 'Lotus nutrient deficiency  and rotting tubers': 15, 'brinjal _cercospora leaf spot': 9, 'naval_anthracnose': 29, 'naval_healthy': 8, 'naval_leaf_galls': 31}
147


In [None]:
!pip install timm torchmetrics torch torchvision torchaudio



In [24]:
from huggingface_hub import login
login(os.getenv('HUG_LOGIN_ID'))

In [None]:
import timm, torchmetrics
from tqdm import tqdm

m = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=len(classes))

def train_setup(m): return m.to("cuda").eval(), 15, "cuda", torch.nn.CrossEntropyLoss(), torch.optim.Adam(params = m.parameters(), lr = 1e-5)
def to_device(batch, device): return batch[0].to(device), batch[1].to(device)
def get_metrics(model, ims, gts, loss_fn, epoch_loss, epoch_acc, epoch_f1): preds = model(ims); loss = loss_fn(preds, gts); return loss, epoch_loss + (loss.item()), epoch_acc + (torch.argmax(preds, dim = 1) == gts).sum().item(), epoch_f1 + f1_score(preds, gts)

m, epochs, device, loss_fn, optimizer = train_setup(m)

f1_score = torchmetrics.F1Score(task = "multiclass", num_classes = len(classes)).to(device)
save_prefix, save_dir = "child_wound", "saved_models"
print("Start training...")
best_acc, best_loss, threshold, not_improved, patience = 0, float("inf"), 0.01, 0, 5
tr_losses, val_losses, tr_accs, val_accs, tr_f1s, val_f1s = [], [], [], [], [], []

best_loss = float(torch.inf)

for epoch in range(epochs):

    epoch_loss, epoch_acc, epoch_f1 = 0, 0, 0
    for idx, batch in tqdm(enumerate(tr_dl)):

        ims, gts = to_device(batch, device)

        loss, epoch_loss, epoch_acc, epoch_f1 = get_metrics(m, ims, gts, loss_fn, epoch_loss, epoch_acc, epoch_f1)
        optimizer.zero_grad(); loss.backward(); optimizer.step()

    tr_loss_to_track = epoch_loss / len(tr_dl)
    tr_acc_to_track  = epoch_acc  / len(tr_dl.dataset)
    tr_f1_to_track   = epoch_f1   / len(tr_dl)
    tr_losses.append(tr_loss_to_track); tr_accs.append(tr_acc_to_track); tr_f1s.append(tr_f1_to_track)

    print(f"{epoch + 1}-epoch train process is completed!")
    print(f"{epoch + 1}-epoch train loss          -> {tr_loss_to_track:.3f}")
    print(f"{epoch + 1}-epoch train accuracy      -> {tr_acc_to_track:.3f}")
    print(f"{epoch + 1}-epoch train f1-score      -> {tr_f1_to_track:.3f}")

    m.eval()
    with torch.no_grad():
        val_epoch_loss, val_epoch_acc, val_epoch_f1 = 0, 0, 0
        for idx, batch in enumerate(vl_dl):
            ims, gts = to_device(batch, device)
            loss, val_epoch_loss, val_epoch_acc, val_epoch_f1 = get_metrics(m, ims, gts, loss_fn, val_epoch_loss, val_epoch_acc, val_epoch_f1)

        val_loss_to_track = val_epoch_loss / len(vl_dl)
        val_acc_to_track  = val_epoch_acc  / len(vl_dl.dataset)
        val_f1_to_track   = val_epoch_f1   / len(vl_dl)
        val_losses.append(val_loss_to_track); val_accs.append(val_acc_to_track); val_f1s.append(val_f1_to_track)

        print(f"{epoch + 1}-epoch validation process is completed!")
        print(f"{epoch + 1}-epoch validation loss     -> {val_loss_to_track:.3f}")
        print(f"{epoch + 1}-epoch validation accuracy -> {val_acc_to_track:.3f}")
        print(f"{epoch + 1}-epoch validation f1-score -> {val_f1_to_track:.3f}")

        if val_loss_to_track < (best_loss + threshold):
            os.makedirs(save_dir, exist_ok = True)
            best_loss = val_loss_to_track
            torch.save(m.state_dict(), f"{save_dir}/{save_prefix}_best_model.pth")

        else:
            not_improved += 1
            print(f"Loss value did not decrease for {not_improved} epochs")
            if not_improved == patience:
                print(f"Stop training since loss value did not decrease for {patience} epochs.")
                break

Start training...


7it [00:04,  1.60it/s]

1-epoch train process is completed!
1-epoch train loss          -> 1.726
1-epoch train accuracy      -> 0.402
1-epoch train f1-score      -> 0.396





1-epoch validation process is completed!
1-epoch validation loss     -> 1.543
1-epoch validation accuracy -> 0.455
1-epoch validation f1-score -> 0.469


7it [00:04,  1.67it/s]

2-epoch train process is completed!
2-epoch train loss          -> 0.748
2-epoch train accuracy      -> 0.676
2-epoch train f1-score      -> 0.690





2-epoch validation process is completed!
2-epoch validation loss     -> 0.943
2-epoch validation accuracy -> 0.636
2-epoch validation f1-score -> 0.594


7it [00:04,  1.64it/s]

3-epoch train process is completed!
3-epoch train loss          -> 0.331
3-epoch train accuracy      -> 0.922
3-epoch train f1-score      -> 0.929





3-epoch validation process is completed!
3-epoch validation loss     -> 0.695
3-epoch validation accuracy -> 0.773
3-epoch validation f1-score -> 0.740


7it [00:04,  1.49it/s]

4-epoch train process is completed!
4-epoch train loss          -> 0.169
4-epoch train accuracy      -> 0.971
4-epoch train f1-score      -> 0.973





4-epoch validation process is completed!
4-epoch validation loss     -> 0.619
4-epoch validation accuracy -> 0.773
4-epoch validation f1-score -> 0.740


7it [00:04,  1.60it/s]

5-epoch train process is completed!
5-epoch train loss          -> 0.064
5-epoch train accuracy      -> 0.990
5-epoch train f1-score      -> 0.991





5-epoch validation process is completed!
5-epoch validation loss     -> 0.579
5-epoch validation accuracy -> 0.773
5-epoch validation f1-score -> 0.740


7it [00:04,  1.74it/s]

6-epoch train process is completed!
6-epoch train loss          -> 0.035
6-epoch train accuracy      -> 0.990
6-epoch train f1-score      -> 0.991





6-epoch validation process is completed!
6-epoch validation loss     -> 0.623
6-epoch validation accuracy -> 0.818
6-epoch validation f1-score -> 0.823
Loss value did not decrease for 1 epochs


7it [00:03,  1.76it/s]

7-epoch train process is completed!
7-epoch train loss          -> 0.032
7-epoch train accuracy      -> 0.990
7-epoch train f1-score      -> 0.991





7-epoch validation process is completed!
7-epoch validation loss     -> 0.801
7-epoch validation accuracy -> 0.818
7-epoch validation f1-score -> 0.823
Loss value did not decrease for 2 epochs


7it [00:04,  1.56it/s]

8-epoch train process is completed!
8-epoch train loss          -> 0.031
8-epoch train accuracy      -> 0.990
8-epoch train f1-score      -> 0.991





8-epoch validation process is completed!
8-epoch validation loss     -> 0.612
8-epoch validation accuracy -> 0.818
8-epoch validation f1-score -> 0.823
Loss value did not decrease for 3 epochs


7it [00:04,  1.71it/s]

9-epoch train process is completed!
9-epoch train loss          -> 0.014
9-epoch train accuracy      -> 1.000
9-epoch train f1-score      -> 1.000





9-epoch validation process is completed!
9-epoch validation loss     -> 0.587
9-epoch validation accuracy -> 0.773
9-epoch validation f1-score -> 0.740


7it [00:03,  1.76it/s]

10-epoch train process is completed!
10-epoch train loss          -> 0.039
10-epoch train accuracy      -> 0.980
10-epoch train f1-score      -> 0.982





10-epoch validation process is completed!
10-epoch validation loss     -> 0.511
10-epoch validation accuracy -> 0.818
10-epoch validation f1-score -> 0.823


7it [00:03,  1.76it/s]

11-epoch train process is completed!
11-epoch train loss          -> 0.022
11-epoch train accuracy      -> 0.990
11-epoch train f1-score      -> 0.991





11-epoch validation process is completed!
11-epoch validation loss     -> 0.492
11-epoch validation accuracy -> 0.818
11-epoch validation f1-score -> 0.823


7it [00:04,  1.57it/s]

12-epoch train process is completed!
12-epoch train loss          -> 0.017
12-epoch train accuracy      -> 0.980
12-epoch train f1-score      -> 0.982





12-epoch validation process is completed!
12-epoch validation loss     -> 0.555
12-epoch validation accuracy -> 0.818
12-epoch validation f1-score -> 0.823
Loss value did not decrease for 4 epochs


7it [00:04,  1.56it/s]

13-epoch train process is completed!
13-epoch train loss          -> 0.026
13-epoch train accuracy      -> 0.990
13-epoch train f1-score      -> 0.991





13-epoch validation process is completed!
13-epoch validation loss     -> 0.602
13-epoch validation accuracy -> 0.818
13-epoch validation f1-score -> 0.823
Loss value did not decrease for 5 epochs
Stop training since loss value did not decrease for 5 epochs.
