In [3]:
import os
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [4]:
# Paths
image_dir = r'Data_Set_Larch_Casebearer/Imagedata'
augmented_dir = r'Data_Set_Larch_Casebearer/Augmented_Images'

# Create directory to save augmented images if it doesn't exist
if not os.path.exists(augmented_dir):
    os.makedirs(augmented_dir)

In [None]:
# Define ImageNet normalization (mean and std)
imagenet_normalization = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Define preprocessing transformations: resize, to tensor, normalization
preprocess = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),          
    imagenet_normalization,         
])

# Define augmentation transformations
augmentations = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),   
    transforms.RandomRotation(20),           
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)  
])

# Function to unnormalize the image for display or saving
def unnormalize(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return tensor * std + mean

# Convert tensor to NumPy array for display or saving
def tensor_to_numpy(tensor):
    tensor = tensor.permute(1, 2, 0)  # Change dimensions from [C, H, W] to [H, W, C]
    tensor = tensor.detach().numpy()  # Convert to NumPy array
    tensor = np.clip(tensor, 0, 1)    # Clip the values between 0 and 1 for proper display
    return tensor

# Iterate over all images in the dataset
for img_name in os.listdir(image_dir):
    if img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
        # Load the image
        img_path = os.path.join(image_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        # Apply preprocessing and augmentation
        image_tensor = preprocess(image)
        augmented_image_tensor = augmentations(image_tensor)
        
        # Unnormalize the augmented image for visualization or saving
        unnormalized_augmented_image = unnormalize(augmented_image_tensor)
        
        # Convert to NumPy array for displaying or saving
        augmented_image_np = tensor_to_numpy(unnormalized_augmented_image)
        
        # Convert to PIL Image and save the augmented image
        augmented_image_pil = Image.fromarray((augmented_image_np * 255).astype(np.uint8))
        augmented_img_name = f"augmented_{img_name}"
        augmented_img_path = os.path.join(augmented_dir, augmented_img_name)
        augmented_image_pil.save(augmented_img_path)
print("Image augmentation complete!")

In [1]:
plt.figure(figsize=(5, 5))
        plt.imshow(augmented_image_np)
        plt.title(f'Augmented Image: {img_name}')
        plt.axis('off')  # Hide axis for better display
        plt.show()

IndentationError: unexpected indent (2917902397.py, line 2)

In [37]:
import xml.etree.ElementTree as ET
import os

# Function to parse the annotation file and count tree categories
def count_tree_categories(annotation_file):
    # Initialize counts
    counts = {
        'Healthy (H)': 0,
        'Light Damage (LD)': 0,
        'High Damage (HD)': 0,
        'Other': 0,
        'Total tree':0,
        'Health Score':0,
        'Damaged or Stressed':0,
        'Label':""
    }
    
    # Parse the XML file
    tree = ET.parse(annotation_file)
    root = tree.getroot()
    
    # Iterate through the objects in the annotation
    
    for obj in root.iter('object'):
        treet= obj.find('tree')
        if treet is not None:
            tree_type=treet.text
        else:
            tree_type=None
        d= obj.find('damage')
        if d is not None:
            damage_status=d.text
        else:
            damage_status=None
        
        # Update counts based on damage status
        if damage_status == 'H':
            counts['Healthy (H)'] += 1
        elif damage_status == 'LD':
            counts['Light Damage (LD)'] += 1
        elif damage_status == 'HD':
            counts['High Damage (HD)'] += 1
        elif damage_status== 'other':
            counts['Other'] += 1
        if tree_type:
            counts['Total tree']+=1
    if counts['Total tree']:        
         counts['Health Score']=(counts['Healthy (H)']/counts['Total tree'])*100
         counts['Damaged or Stressed']=((counts['Light Damage (LD)']+counts['High Damage (HD)'])/counts['Total tree'])*100
    if counts['Health Score']>=60:
        counts['Label']='Healthy'
    else:
        counts['Label']='Damaged/Stressed'
        
    return counts

total_counts = {
    'Healthy (H)': 0,
    'Light Damage (LD)': 0,
    'High Damage (HD)': 0,
    'Other': 0
}

# Example usage
image_path = r"C:\Users\User\Desktop\RESEARCH\Data_Set_Larch_Casebearer\Imagedata"
annotation_path = r"C:\Users\User\Desktop\RESEARCH\Data_Set_Larch_Casebearer\Annotations"

# Iterate through images and their corresponding annotations
image_filenames = os.listdir(image_path)

for img_name in image_filenames:
    if img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
        # Construct the corresponding annotation filename
        annotation_file = img_name.replace('.JPG', '.xml').replace('.jpeg', '.xml').replace('.png', '.xml')
        annotation_file_path = os.path.join(annotation_path, annotation_file)
        
        # Count the categories
        if os.path.exists(annotation_file_path):
            counts = count_tree_categories(annotation_file_path)
            print(f"Counts for {img_name}: {counts}")
            total_counts['Healthy (H)'] += counts['Healthy (H)']
            total_counts['Light Damage (LD)'] += counts['Light Damage (LD)']
            total_counts['High Damage (HD)'] += counts['High Damage (HD)']
            total_counts['Other'] += counts['Other']
        else:
            print(f"Annotation file not found for {img_name}: {annotation_file_path}")
# Print the total counts
print("Total counts across the dataset:")
print(f"Healthy (H): {total_counts['Healthy (H)']}")
print(f"Light Damage (LD): {total_counts['Light Damage (LD)']}")
print(f"High Damage (HD): {total_counts['High Damage (HD)']}")
print(f"Other: {total_counts['Other']}")


Counts for B01_0004.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 31, 'High Damage (HD)': 22, 'Other': 4, 'Total tree': 57, 'Health Score': 0.0, 'Damaged or Stressed': 92.98245614035088, 'Label': 'Damaged/Stressed'}
Counts for B01_0005.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 23, 'High Damage (HD)': 9, 'Other': 12, 'Total tree': 44, 'Health Score': 0.0, 'Damaged or Stressed': 72.72727272727273, 'Label': 'Damaged/Stressed'}
Counts for B01_0006.JPG: {'Healthy (H)': 1, 'Light Damage (LD)': 31, 'High Damage (HD)': 2, 'Other': 10, 'Total tree': 44, 'Health Score': 2.272727272727273, 'Damaged or Stressed': 75.0, 'Label': 'Damaged/Stressed'}
Counts for B01_0007.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 8, 'High Damage (HD)': 6, 'Other': 10, 'Total tree': 24, 'Health Score': 0.0, 'Damaged or Stressed': 58.333333333333336, 'Label': 'Damaged/Stressed'}
Counts for B01_0012.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 3, 'High Damage (HD)': 9, 'Other': 35, 'Total tree': 47, 'Health Score': 0

Counts for B02_0008.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 8, 'High Damage (HD)': 0, 'Other': 43, 'Total tree': 51, 'Health Score': 0.0, 'Damaged or Stressed': 15.686274509803921, 'Label': 'Damaged/Stressed'}
Counts for B02_0009.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 1, 'High Damage (HD)': 0, 'Other': 46, 'Total tree': 47, 'Health Score': 0.0, 'Damaged or Stressed': 2.127659574468085, 'Label': 'Damaged/Stressed'}
Counts for B02_0012.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 14, 'High Damage (HD)': 2, 'Other': 51, 'Total tree': 67, 'Health Score': 0.0, 'Damaged or Stressed': 23.88059701492537, 'Label': 'Damaged/Stressed'}
Counts for B02_0013.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 61, 'High Damage (HD)': 2, 'Other': 12, 'Total tree': 75, 'Health Score': 0.0, 'Damaged or Stressed': 84.0, 'Label': 'Damaged/Stressed'}
Counts for B02_0014.JPG: {'Healthy (H)': 8, 'Light Damage (LD)': 58, 'High Damage (HD)': 0, 'Other': 0, 'Total tree': 66, 'Health Score': 12.1212121212121

Counts for B02_0221.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 34, 'Total tree': 34, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B02_0222.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 40, 'Total tree': 40, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B02_0223.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 34, 'Total tree': 34, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B02_0224.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 34, 'Total tree': 34, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B02_0225.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 57, 'Total tree': 57, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}


Counts for B04_0084.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 70, 'High Damage (HD)': 34, 'Other': 21, 'Total tree': 125, 'Health Score': 0.0, 'Damaged or Stressed': 83.2, 'Label': 'Damaged/Stressed'}
Counts for B04_0085.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 115, 'High Damage (HD)': 17, 'Other': 14, 'Total tree': 146, 'Health Score': 0.0, 'Damaged or Stressed': 90.41095890410958, 'Label': 'Damaged/Stressed'}
Counts for B04_0086.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 113, 'High Damage (HD)': 4, 'Other': 13, 'Total tree': 130, 'Health Score': 0.0, 'Damaged or Stressed': 90.0, 'Label': 'Damaged/Stressed'}
Counts for B04_0087.JPG: {'Healthy (H)': 20, 'Light Damage (LD)': 82, 'High Damage (HD)': 2, 'Other': 4, 'Total tree': 108, 'Health Score': 18.51851851851852, 'Damaged or Stressed': 77.77777777777779, 'Label': 'Damaged/Stressed'}
Counts for B04_0088.JPG: {'Healthy (H)': 33, 'Light Damage (LD)': 88, 'High Damage (HD)': 0, 'Other': 9, 'Total tree': 130, 'Health Score': 25.3

Counts for B04_0122.JPG: {'Healthy (H)': 3, 'Light Damage (LD)': 86, 'High Damage (HD)': 20, 'Other': 13, 'Total tree': 122, 'Health Score': 2.459016393442623, 'Damaged or Stressed': 86.88524590163934, 'Label': 'Damaged/Stressed'}
Counts for B04_0123.JPG: {'Healthy (H)': 1, 'Light Damage (LD)': 80, 'High Damage (HD)': 37, 'Other': 11, 'Total tree': 129, 'Health Score': 0.7751937984496124, 'Damaged or Stressed': 90.69767441860465, 'Label': 'Damaged/Stressed'}
Counts for B04_0124.JPG: {'Healthy (H)': 2, 'Light Damage (LD)': 82, 'High Damage (HD)': 33, 'Other': 8, 'Total tree': 125, 'Health Score': 1.6, 'Damaged or Stressed': 92.0, 'Label': 'Damaged/Stressed'}
Counts for B04_0125.JPG: {'Healthy (H)': 6, 'Light Damage (LD)': 94, 'High Damage (HD)': 13, 'Other': 13, 'Total tree': 126, 'Health Score': 4.761904761904762, 'Damaged or Stressed': 84.92063492063492, 'Label': 'Damaged/Stressed'}
Counts for B04_0126.JPG: {'Healthy (H)': 8, 'Light Damage (LD)': 77, 'High Damage (HD)': 18, 'Other': 1

Counts for B05_0066.JPG: {'Healthy (H)': 7, 'Light Damage (LD)': 92, 'High Damage (HD)': 1, 'Other': 7, 'Total tree': 107, 'Health Score': 6.5420560747663545, 'Damaged or Stressed': 86.91588785046729, 'Label': 'Damaged/Stressed'}
Counts for B05_0067.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 86, 'High Damage (HD)': 2, 'Other': 10, 'Total tree': 98, 'Health Score': 0.0, 'Damaged or Stressed': 89.79591836734694, 'Label': 'Damaged/Stressed'}
Counts for B05_0068.JPG: {'Healthy (H)': 1, 'Light Damage (LD)': 80, 'High Damage (HD)': 0, 'Other': 6, 'Total tree': 87, 'Health Score': 1.1494252873563218, 'Damaged or Stressed': 91.95402298850574, 'Label': 'Damaged/Stressed'}
Counts for B05_0069.JPG: {'Healthy (H)': 2, 'Light Damage (LD)': 74, 'High Damage (HD)': 1, 'Other': 3, 'Total tree': 80, 'Health Score': 2.5, 'Damaged or Stressed': 93.75, 'Label': 'Damaged/Stressed'}
Counts for B05_0070.JPG: {'Healthy (H)': 1, 'Light Damage (LD)': 68, 'High Damage (HD)': 4, 'Other': 4, 'Total tree': 77, 'H

Counts for B06_0024.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 97, 'Total tree': 97, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B06_0025.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 105, 'Total tree': 105, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B06_0026.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 104, 'Total tree': 104, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B06_0027.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 97, 'Total tree': 97, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B06_0028.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 84, 'Total tree': 84, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stresse

Counts for B06_0115.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 68, 'Total tree': 68, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B06_0116.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 91, 'Total tree': 91, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B06_0117.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 98, 'Total tree': 98, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B06_0118.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 84, 'Total tree': 84, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B06_0119.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 89, 'Total tree': 89, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}


Counts for B07_0157.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 53, 'Total tree': 53, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B07_0158.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 51, 'Total tree': 51, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B07_0159.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 64, 'Total tree': 64, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B07_0160.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 76, 'Total tree': 76, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B07_0161.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 66, 'Total tree': 66, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}


Counts for B09_0097.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 128, 'Total tree': 128, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B09_0098.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 116, 'Total tree': 116, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B09_0099.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 109, 'Total tree': 109, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B09_0100.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 72, 'Total tree': 72, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B09_0101.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 49, 'Total tree': 49, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stres

Counts for B10_0101.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 73, 'Total tree': 73, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B10_0102.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 75, 'Total tree': 75, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B10_0103.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 82, 'Total tree': 82, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B10_0104.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 76, 'Total tree': 76, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}
Counts for B10_0105.JPG: {'Healthy (H)': 0, 'Light Damage (LD)': 0, 'High Damage (HD)': 0, 'Other': 66, 'Total tree': 66, 'Health Score': 0.0, 'Damaged or Stressed': 0.0, 'Label': 'Damaged/Stressed'}


In [3]:
!pip install timm


Collecting timm
  Using cached timm-1.0.11-py3-none-any.whl (2.3 MB)
Collecting huggingface_hub
  Using cached huggingface_hub-0.26.0-py3-none-any.whl (447 kB)
Collecting fsspec>=2023.5.0
  Using cached fsspec-2024.9.0-py3-none-any.whl (179 kB)
Installing collected packages: fsspec, huggingface_hub, timm
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2022.7.1
    Uninstalling fsspec-2022.7.1:
      Successfully uninstalled fsspec-2022.7.1
Successfully installed fsspec-2024.9.0 huggingface_hub-0.26.0 timm-1.0.11


In [4]:
import timm

In [6]:
import torch
import torch.nn as nn
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)

# Modify the final classification layer to match the number of health categories (e.g., 4)
num_classes = 4  # Assuming 4 classes: Healthy, Light Damage, High Damage, Other
model.head = nn.Linear(model.head.in_features, num_classes)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [8]:
import numpy as np
import xml.etree.ElementTree as ET

def create_segmentation_mask(annotation_file, image_size):
    mask = np.zeros(image_size, dtype=np.uint8)  # Initialize mask with zeros (default: 'Other')
    
    tree = ET.parse(annotation_file)
    root = tree.getroot()
    
    for obj in root.iter('object'):
        bndbox = obj.find('bndbox')
        xmin = int(bndbox.find('xmin').text)
        ymin = int(bndbox.find('ymin').text)
        xmax = int(bndbox.find('xmax').text)
        ymax = int(bndbox.find('ymax').text)

        # Get damage type from annotation
        damage_status = obj.find('damage').text
        
        # Map the damage status to a number (you can modify this mapping)
        if damage_status == 'H':
            label = 0  # Healthy
        elif damage_status == 'LD':
            label = 1  # Light Damage
        elif damage_status == 'HD':
            label = 2  # High Damage
        else:
            label = 3  # Other

        # Fill the bounding box area in the mask with the label
        mask[ymin:ymax, xmin:xmax] = label
    
    return mask


In [9]:
from torch.utils.data import Dataset

class ForestSegmentationDataset(Dataset):
    def __init__(self, image_dir, annotation_dir, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        self.image_filenames = os.listdir(image_dir)
    
    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.image_dir, img_name)
        annotation_file = os.path.join(self.annotation_dir, img_name.replace('.JPG', '.xml'))
        
        # Load image
        image = Image.open(img_path).convert('RGB')
        
        # Create segmentation mask
        mask = create_segmentation_mask(annotation_file, image.size)

        # Apply transforms
        if self.transform:
            image = self.transform(image)
            mask = torch.tensor(mask, dtype=torch.long)  # Convert mask to tensor

        return image, mask


In [10]:
import torchvision.models.segmentation as models

# Load the segmentation model with a pre-trained Swin Transformer backbone
model = models.deeplabv3_resnet50(pretrained=True)  # You can also use Swin backbone

# Modify the classifier to output 4 classes (Healthy, Light Damage, High Damage, Other)
model.classifier[4] = nn.Conv2d(256, 4, kernel_size=(1, 1), stride=(1, 1))

# Move to GPU if available
model = model.to(device)


Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to C:\Users\User/.cache\torch\hub\checkpoints\deeplabv3_resnet50_coco-cd0a2569.pth
100%|███████████████████████████████████████████████████████████████████████████████| 161M/161M [00:46<00:00, 3.60MB/s]


In [12]:
# Loss function for segmentation (CrossEntropyLoss works well for multi-class segmentation)
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs=15
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)['out']  # DeepLabV3 returns a dict, so we get 'out'
        loss = criterion(outputs, masks)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}")


NameError: name 'train_loader' is not defined

In [13]:
def calculate_health_metrics(predicted_mask):
    # Count the number of pixels for each class
    total_pixels = predicted_mask.size
    healthy_pixels = (predicted_mask == 0).sum()
    light_damage_pixels = (predicted_mask == 1).sum()
    high_damage_pixels = (predicted_mask == 2).sum()
    other_pixels = (predicted_mask == 3).sum()

    # Calculate percentages
    healthy_percentage = (healthy_pixels / total_pixels) * 100
    light_damage_percentage = (light_damage_pixels / total_pixels) * 100
    high_damage_percentage = (high_damage_pixels / total_pixels) * 100
    other_percentage = (other_pixels / total_pixels) * 100

    return {
        'Healthy %': healthy_percentage,
        'Light Damage %': light_damage_percentage,
        'High Damage %': high_damage_percentage,
        'Other %': other_percentage
    }

# Example usage
predicted_mask = model(image)['out'].argmax(1)  # Get the predicted class for each pixel
metrics = calculate_health_metrics(predicted_mask)
print(metrics)


NameError: name 'image' is not defined