In [2]:
!pip install opencv-python
!pip install numpy
!pip install matplotlib




[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


# Without Watershed

In [2]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks

def calculate_projection_profile_and_crop_lines(image_path, output_dir):
    """Segment a page image into text lines using its horizontal projection profile.

    The page is binarized with Otsu's threshold (ink -> 255) and the ink in every
    row is summed. Each run of consecutive rows whose sum exceeds 10% of the
    largest row sum is treated as one text line.

    Line ranges are then adjusted before cropping:
    - the first line is padded 10 rows upwards;
    - the last line is padded 7 rows downwards (clamped to the image height);
    - neighbouring lines are split at the midpoint of the gap between them.

    Each line is cropped from the original grayscale image and written to
    ``<output_dir>/line_<k>.png`` (k starting at 1). Finally the binarized image
    and the projection profile (with the line boundaries) are plotted.

    Args:
        image_path: Path of the page image to read.
        output_dir: Directory that receives the cropped line images; it is
            created if it does not exist.

    Returns:
        None. Problems are reported with ``print``.
    """
    import os

    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        print("Error: Unable to load image.")
        return
    height = image.shape[0]

    # THRESH_BINARY_INV makes dark ink 255 and the light background 0.
    _, binary_image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Ink per row. np.sum accumulates uint8 input in the platform integer type,
    # so long rows of ink do not overflow.
    horizontal_projection = np.sum(binary_image, axis=1)

    # Rows holding more than 10% of the densest row's ink count as text rows.
    threshold = 0.1 * np.max(horizontal_projection)

    # Collect [start, end) row ranges of consecutive text rows.
    line_ranges = []
    start_row = None
    for row, value in enumerate(horizontal_projection):
        if value > threshold:
            if start_row is None:
                start_row = row
        elif start_row is not None:
            line_ranges.append((start_row, row))
            start_row = None
    # A line that runs into the bottom edge never meets a blank row; close it here
    # instead of silently dropping it.
    if start_row is not None:
        line_ranges.append((start_row, height))

    if line_ranges:
        line_ranges[0] = (max(0, line_ranges[0][0] - 10), line_ranges[0][1])
        # Rows run along shape[0]; clamping to the width (shape[1]) could push the
        # end above the start and produce an empty crop.
        line_ranges[-1] = (line_ranges[-1][0], min(height, line_ranges[-1][1] + 7))

    # Move each shared boundary to the middle of the gap so no rows are lost.
    for i in range(1, len(line_ranges)):
        mid = (line_ranges[i - 1][1] + line_ranges[i][0]) // 2
        line_ranges[i - 1] = (line_ranges[i - 1][0], mid)
        line_ranges[i] = (mid, line_ranges[i][1])

    # Crop and save each detected line from the original grayscale image.
    os.makedirs(output_dir, exist_ok=True)
    for idx, (start, end) in enumerate(line_ranges):
        cropped_line = image[start:end, :]
        output_path = os.path.join(output_dir, f"line_{idx + 1}.png")
        if cv2.imwrite(output_path, cropped_line):
            print(f"Saved cropped line {idx + 1} to {output_path}")
        else:
            print(f"Warning: Failed to save cropped line {idx + 1} to {output_path}")

    # Plot the binarized page and the projection profile with detected boundaries.
    plt.figure(figsize=(10, 6))
    plt.subplot(2, 1, 1)
    plt.imshow(binary_image, cmap='gray')
    plt.title("Binarized Image")
    plt.axis('off')

    plt.subplot(2, 1, 2)
    plt.plot(horizontal_projection, range(height), color='b', label="Projection Profile")
    for i, (start, end) in enumerate(line_ranges):
        plt.axhline(y=start, color='r', linestyle='--', label="Line Start" if i == 0 else "_nolegend_")
        plt.axhline(y=end, color='g', linestyle='--', label="Line End" if i == 0 else "_nolegend_")
    plt.gca().invert_yaxis()
    plt.title("Horizontal Projection Profile")
    plt.xlabel("Sum of Pixel Intensities")
    plt.ylabel("Row Index")
    plt.legend(loc="upper right")
    plt.show()

    print("Detected Line Ranges (Start Row, End Row):")
    for start, end in line_ranges:
        print(f"Line: {start} to {end}")

# Example usage
# Replace the paths below with your image path and output directory
# calculate_projection_profile_and_crop_lines(image_path, output_dir)


# Watershed

In [6]:
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt

def calculate_lines_watershed(image_path, output_dir, kernel_width=80, fg_ratio=0.3, min_area=10):
    """Segment a page into text lines with a marker-based watershed.

    Processing steps:
    1. Otsu binarization (ink -> 255).
    2. Horizontal dilation, so the words of one line merge into a single blob.
    3. A distance transform of the dilated mask.
    4. A threshold at ``fg_ratio * max distance`` to obtain seed markers.
    5. ``cv2.watershed`` assigns the remaining ("unknown") dilated ink to the seeds.

    The bounding box of every resulting region is cropped from the grayscale page
    and written to ``<output_dir>/line_<k>.png``. The binarized page and the
    watershed boundaries (in red) are then shown side by side.

    Args:
        image_path: Path of the page image to read.
        output_dir: Directory for the cropped lines; created if missing.
        kernel_width: Width in pixels of the horizontal dilation kernel. Increase
            it if the words of one line end up in separate regions.
        fg_ratio: Fraction of the maximum distance-transform value used to pick
            seed pixels. Adjust it if lines merge or split incorrectly.
        min_area: Regions with fewer pixels than this are skipped as noise.

    Returns:
        None. Progress and errors are reported with ``print``.
    """
    import os

    # 1) Read the page in grayscale.
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        print("Error: Unable to load image.")
        return

    # 2) Otsu binarization, inverted so ink is 255 on a 0 background.
    _, binary_inv = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # 3) Horizontal dilation to connect the text of one line.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_width, 1))
    dilated = cv2.dilate(binary_inv, kernel, iterations=1)

    # 4) Distance transform: line cores get the largest values.
    dist_transform = cv2.distanceTransform(dilated, cv2.DIST_L2, 5)

    # 5) Sure foreground = pixels far enough from the blob border.
    _, sure_fg = cv2.threshold(dist_transform, fg_ratio * dist_transform.max(), 255, cv2.THRESH_BINARY)
    sure_fg = np.uint8(sure_fg)

    # 6) Dilated ink that is not a seed is left for the watershed to decide.
    unknown = cv2.subtract(dilated, sure_fg)

    # 7) Label seeds. Shift labels so the background is 1 (a known region) and
    #    reserve 0 for "unknown", which is what cv2.watershed fills in.
    num_markers, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1
    markers[unknown == 255] = 0

    # 8-9) Watershed on a 3-channel copy; boundary pixels come back as -1.
    image_color = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    markers = cv2.watershed(image_color, markers)
    image_color[markers == -1] = [0, 0, 255]

    # 10) Every label >= 2 is a candidate text line; crop its bounding box.
    os.makedirs(output_dir, exist_ok=True)

    line_id = 1
    for marker_value in range(2, num_markers + 1):
        mask = np.uint8(markers == marker_value)
        if cv2.countNonZero(mask) < min_area:
            continue  # noise or spurious region

        x, y, w, h = cv2.boundingRect(mask)
        cropped_line = image[y:y + h, x:x + w]

        output_path = os.path.join(output_dir, f"line_{line_id}.png")
        if not cv2.imwrite(output_path, cropped_line):
            print(f"Warning: Failed to save cropped line {line_id} to {output_path}")
            continue
        print(f"Saved cropped line {line_id} to {output_path}")
        line_id += 1

    # 11) Left: binarized (inverted) page. Right: watershed boundaries in red.
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Binarized (Inverted)")
    plt.imshow(binary_inv, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.title("Watershed Boundaries")
    plt.imshow(image_color[..., ::-1])  # BGR -> RGB for matplotlib
    plt.axis('off')
    plt.show()

    print("Watershed-based line segmentation complete!")

# Example usage:
# calculate_lines_watershed(image_path, output_dir)


# With watershed and horizontal projection profile

In [19]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks

def calculate_projection_profile_and_crop_lines_with_watershed(image_path, output_dir):
    """Find text lines by running a watershed on the page's horizontal projection.

    Building the projection image:
    1. Otsu-binarize the page (ink -> 255) and sum the ink per row.
    2. Min-max scale that 1-D profile to 0..255.
    3. Invert it so ink-dense rows are dark.
    4. Repeat it 100 columns wide to form a small 2-D image.

    The usual OpenCV marker recipe then seeds ``cv2.watershed`` on that image:
    Otsu threshold, opening, dilation for sure background, and a distance
    transform for sure foreground. Each watershed region is read back as a row
    range.

    Ranges are padded at the page edges (10 rows up, 7 rows down, clamped to the
    image height) and split at the midpoints of the gaps between them. They are
    then cropped from the grayscale page and written to
    ``<output_dir>/line_<k>.png``. Diagnostic figures are shown along the way.

    Args:
        image_path: Path of the page image to read.
        output_dir: Directory for the cropped lines; created if missing.

    Returns:
        None. Progress and errors are reported with ``print``.
    """
    import os

    proj_width = 100               # width of the synthetic projection image
    probe_col = proj_width // 2    # column sampled to read watershed labels back
    min_line_rows = 5              # regions spanning fewer rows are treated as noise

    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        print("Error: Unable to load image.")
        return
    height = image.shape[0]

    # Ink -> 255, background -> 0.
    _, binary_image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    horizontal_projection = np.sum(binary_image, axis=1)

    # Min-max scale in floating point *before* converting to uint8. The raw row
    # sums are multiples of 255 up to 255 * width; casting them directly to uint8
    # wraps modulo 256 and scrambles the profile.
    profile = horizontal_projection.astype(np.float64)
    lo, hi = profile.min(), profile.max()
    if hi <= lo:
        print("Error: Projection profile is flat; no lines to separate.")
        return
    proj_img = np.round((profile - lo) * (255.0 / (hi - lo))).astype(np.uint8)
    proj_img = 255 - proj_img  # ink-dense rows become dark basins
    proj_img_2d = np.repeat(proj_img[:, np.newaxis], proj_width, axis=1)

    # Text rows are dark after the inversion, so THRESH_BINARY_INV makes them the
    # 255 foreground (THRESH_BINARY would select the gaps between lines instead).
    _, thresh = cv2.threshold(proj_img_2d, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Noise removal, then sure background / sure foreground / unknown.
    kernel = np.ones((3, 3), np.uint8)
    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)
    sure_bg = cv2.dilate(opening, kernel, iterations=3)
    dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
    _, sure_fg = cv2.threshold(dist_transform, 0.5 * dist_transform.max(), 255, cv2.THRESH_BINARY)
    sure_fg = np.uint8(sure_fg)
    unknown = cv2.subtract(sure_bg, sure_fg)

    # Background -> 1, seeds -> 2.., unknown -> 0 (filled by the watershed).
    _, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1
    markers[unknown == 255] = 0

    proj_img_color = cv2.cvtColor(proj_img_2d, cv2.COLOR_GRAY2BGR)
    markers = cv2.watershed(proj_img_color, markers)

    # Read every region back as a [start, end) row range along the probe column.
    line_ranges = []
    for label in range(2, int(markers.max()) + 1):
        rows = np.flatnonzero(markers[:, probe_col] == label)
        if rows.size and rows[-1] - rows[0] > min_line_rows:
            # +1 turns the inclusive last row into an exclusive slice end.
            line_ranges.append((int(rows[0]), int(rows[-1]) + 1))
    line_ranges.sort()

    # Diagnostic figure: page, profile, projection image, watershed regions.
    plt.figure(figsize=(14, 10))

    plt.subplot(2, 2, 1)
    plt.imshow(binary_image, cmap='gray')
    plt.title("Binarized Image")
    plt.axis('off')

    plt.subplot(2, 2, 2)
    plt.plot(horizontal_projection, range(height), color='blue')
    plt.gca().invert_yaxis()
    plt.title("Horizontal Projection Profile")
    plt.xlabel("Sum of Pixel Intensities")
    plt.ylabel("Row Index")

    plt.subplot(2, 2, 3)
    plt.imshow(proj_img_2d, cmap='gray', aspect='auto')
    plt.title("Projection Image for Watershed")
    plt.axis('off')

    marker_vis = np.where(markers >= 2, 255, 0).astype(np.uint8)
    plt.subplot(2, 2, 4)
    plt.imshow(marker_vis, cmap='jet', aspect='auto')
    plt.title("Watershed Markers on Projection Image")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    if line_ranges:
        line_ranges[0] = (max(0, line_ranges[0][0] - 10), line_ranges[0][1])
        # Clamp to the image height (shape[0]); rows are the vertical axis.
        line_ranges[-1] = (line_ranges[-1][0], min(height, line_ranges[-1][1] + 7))

    # Move each shared boundary to the middle of the gap between two lines.
    for i in range(1, len(line_ranges)):
        mid = (line_ranges[i - 1][1] + line_ranges[i][0]) // 2
        line_ranges[i - 1] = (line_ranges[i - 1][0], mid)
        line_ranges[i] = (mid, line_ranges[i][1])

    # Crop and save each detected line from the original grayscale image.
    os.makedirs(output_dir, exist_ok=True)
    for idx, (start, end) in enumerate(line_ranges):
        cropped_line = image[start:end, :]
        output_path = os.path.join(output_dir, f"line_{idx + 1}.png")
        if cv2.imwrite(output_path, cropped_line):
            print(f"Saved cropped line {idx + 1} to {output_path}")
        else:
            print(f"Warning: Failed to save cropped line {idx + 1} to {output_path}")

    # Final figure: the profile with the boundaries that were actually cropped.
    plt.figure(figsize=(10, 6))
    plt.subplot(2, 1, 1)
    plt.imshow(binary_image, cmap='gray')
    plt.title("Binarized Image")
    plt.axis('off')

    plt.subplot(2, 1, 2)
    plt.plot(horizontal_projection, range(height), color='b')
    for start, end in line_ranges:
        plt.axhline(y=start, color='r', linestyle='--')
        plt.axhline(y=end, color='g', linestyle='--')
    plt.gca().invert_yaxis()
    plt.title("Horizontal Projection Profile")
    plt.xlabel("Sum of Pixel Intensities")
    plt.ylabel("Row Index")
    plt.show()

    print("Detected Line Ranges (Start Row, End Row):")
    for start, end in line_ranges:
        print(f"Line: {start} to {end}")

# Example usage
# Replace the paths below with your image path and output directory
# calculate_projection_profile_and_crop_lines_with_watershed(image_path, output_dir)


# Testing

Line images numbered from 10,000 onward come from the IAM dataset; lines from the custom dataset are numbered from 0.

In [None]:
# Every input/output folder of the line-extraction experiments sits under one
# common project root; build each path from that root plus its sub-folder.
image_path, output_dir, output_dir_without, output_dir_with_water_horizon = (
    'D://Thesis//Mass_Line_Extraction//' + sub_folder
    for sub_folder in (
        'input_images//',
        'output_images//',
        'output_images_without//',
        'output_with_water_horizon//',
    )
)

Line extraction code (batch version that numbers lines continuously across pages)

In [19]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks

def testing_calculate_projection_profile_and_crop_lines(image_path, output_dir, cnt):
    """Crop the text lines of one page and save them with globally numbered names.

    This is a batch-friendly variant of the projection-profile segmenter. Lines
    found on this page are written to ``<output_dir>/line_<cnt>.png``,
    ``line_<cnt + 1>.png``, and so on, so a caller can thread ``cnt`` through many
    pages without file-name collisions.

    Args:
        image_path: Path of the page image to read.
        output_dir: Destination folder; created if it does not exist.
        cnt: Number used for the first line saved from this page (must be an int).

    Returns:
        The next unused number: ``cnt`` plus the number of lines processed.
        Any exception is caught and printed, and the counter reached so far is
        returned, so one bad file does not abort a batch run.
    """
    try:
        if not isinstance(cnt, int):
            raise ValueError("Counter 'cnt' must be an integer.")

        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
            print(f"Output directory '{output_dir}' did not exist and was created.")

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Error: Unable to load image from {image_path}")
        height = image.shape[0]

        # Ink -> 255, background -> 0.
        _, binary_image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

        # Ink per row; rows above 10% of the densest row count as text rows.
        horizontal_projection = np.sum(binary_image, axis=1)
        threshold = 0.1 * np.max(horizontal_projection)

        # Collect [start, end) row ranges of consecutive text rows.
        line_ranges = []
        start_row = None
        for row, value in enumerate(horizontal_projection):
            if value > threshold:
                if start_row is None:
                    start_row = row
            elif start_row is not None:
                line_ranges.append((start_row, row))
                start_row = None
        # A line touching the bottom edge never meets a blank row; keep it.
        if start_row is not None:
            line_ranges.append((start_row, height))

        # Pad the first line upwards and the last line downwards.
        if line_ranges:
            line_ranges[0] = (max(0, line_ranges[0][0] - 10), line_ranges[0][1])
            line_ranges[-1] = (line_ranges[-1][0], min(height, line_ranges[-1][1] + 7))

        # Split neighbouring lines at the midpoint of the gap between them.
        for i in range(1, len(line_ranges)):
            mid = (line_ranges[i - 1][1] + line_ranges[i][0]) // 2
            line_ranges[i - 1] = (line_ranges[i - 1][0], mid)
            line_ranges[i] = (mid, line_ranges[i][1])

        # Crop from the original grayscale page and save with the running number.
        for start, end in line_ranges:
            output_path = os.path.join(output_dir, f"line_{cnt}.png")
            if cv2.imwrite(output_path, image[start:end, :]):
                print(f"Saved cropped line {cnt} to {output_path}")
            else:
                print(f"Warning: Failed to save cropped line {cnt} to {output_path}")
            cnt += 1

        return cnt

    except Exception as e:
        print(f"An error occurred: {e}")
        return cnt  # Return current cnt even if something went wrong


Extracting lines from the custom dataset

In [None]:
import os

image_path = 'D://Thesis//Mass_Line_Extraction//input_images//'
output_dir_without = 'D://Thesis//Mass_Line_Extraction//output_images_without//'

# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting makes
# the global line numbering reproducible between runs and machines.
files = sorted(os.listdir(image_path))

# Custom-dataset lines are numbered from 0. The extractor returns the next unused
# number, so numbering continues across pages without overwriting files.
cnt = 0
for file in files:
    full_path = os.path.join(image_path, file)
    cnt = testing_calculate_projection_profile_and_crop_lines(full_path, output_dir_without, cnt)

Saved cropped line 0 to D://Thesis//Mass_Line_Extraction//output_images_without//line_0.png
Saved cropped line 1 to D://Thesis//Mass_Line_Extraction//output_images_without//line_1.png
Saved cropped line 2 to D://Thesis//Mass_Line_Extraction//output_images_without//line_2.png
Saved cropped line 3 to D://Thesis//Mass_Line_Extraction//output_images_without//line_3.png
Saved cropped line 4 to D://Thesis//Mass_Line_Extraction//output_images_without//line_4.png
Saved cropped line 5 to D://Thesis//Mass_Line_Extraction//output_images_without//line_5.png
Saved cropped line 6 to D://Thesis//Mass_Line_Extraction//output_images_without//line_6.png
Saved cropped line 7 to D://Thesis//Mass_Line_Extraction//output_images_without//line_7.png
Saved cropped line 8 to D://Thesis//Mass_Line_Extraction//output_images_without//line_8.png
Saved cropped line 9 to D://Thesis//Mass_Line_Extraction//output_images_without//line_9.png
Saved cropped line 10 to D://Thesis//Mass_Line_Extraction//output_images_without

Extracting lines from the IAM dataset

In [21]:
import os

# For printed
# image_path = 'D://Thesis//Mass_Line_Extraction//output_images//printed'
# output_dir_without = 'D://Thesis//Mass_Line_Extraction//output_images_data//Printed'

# For Handwriting
image_path = 'D://Thesis//Mass_Line_Extraction//output_images//handwritten'
output_dir_without = 'D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten'

# Sorted so the numbering is reproducible (os.listdir order is arbitrary).
files = sorted(os.listdir(image_path))

# IAM-derived lines are numbered from 10000 so they are distinguishable from the
# custom-dataset lines, which are numbered from 0.
cnt = 10000
for file in files:
    full_path = os.path.join(image_path, file)
    cnt = testing_calculate_projection_profile_and_crop_lines(full_path, output_dir_without, cnt)

Saved cropped line 10000 to D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten\line_10000.png
Saved cropped line 10001 to D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten\line_10001.png
Saved cropped line 10002 to D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten\line_10002.png
Saved cropped line 10003 to D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten\line_10003.png
Saved cropped line 10004 to D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten\line_10004.png
Saved cropped line 10005 to D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten\line_10005.png
Saved cropped line 10006 to D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten\line_10006.png
Saved cropped line 10007 to D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten\line_10007.png
Saved cropped line 10008 to D://Thesis//Mass_Line_Extraction//output_images_data//Handwritten\line_10008.png
Saved cropped line 

# Separating the handwritten and printed portions of IAM forms

In [16]:
import cv2


def seperate_image_section_cropping_handwritten_printed_IAM(image_path, handwritten_dir, printed_dir, cnt,
                                                           printed_rows=(360, 620),
                                                           handwritten_top=700,
                                                           handwritten_bottom_margin=900):
    """Split one IAM form scan into its printed and handwritten bands.

    The printed band is rows ``printed_rows[0]:printed_rows[1]``. The handwritten
    band runs from ``handwritten_top`` down to ``handwritten_bottom_margin`` rows
    above the bottom edge. Both crops keep the full width and are saved as
    ``line_<cnt>.png``.

    NOTE(review): the output folders are crossed over relative to the parameter
    names. The *printed* band is written into ``handwritten_dir`` and the
    *handwritten* band into ``printed_dir``. The cell that calls this passes the
    printed folder as ``handwritten_dir`` (and vice versa), which puts each crop in
    the correctly named folder. The behaviour is kept for compatibility; rename
    both sides together if you change it.

    Args:
        image_path: Path of the IAM form image.
        handwritten_dir: Folder that receives the printed-band crop (see note).
        printed_dir: Folder that receives the handwritten-band crop (see note).
        cnt: Number used in the output file name.
        printed_rows: ``(top, bottom)`` rows of the printed band.
        handwritten_top: First row of the handwritten band.
        handwritten_bottom_margin: Rows excluded at the bottom of the page.

    Returns:
        None. Progress and errors are reported with ``print``.
    """
    import os

    image = cv2.imread(image_path)
    if image is None:
        print(f"Error: Unable to load image from {image_path}")
        return

    height = image.shape[0]
    printed_top, printed_bottom = printed_rows
    # Clamp at 0: a negative end index would wrap around and silently select the
    # wrong rows on short images.
    handwritten_bottom = max(0, height - handwritten_bottom_margin)

    printed_crop = image[printed_top:printed_bottom, :]
    handwritten_crop = image[handwritten_top:handwritten_bottom, :]
    if printed_crop.size == 0 or handwritten_crop.size == 0:
        print(f"Error: {image_path} is too small for the configured crop bounds.")
        return

    os.makedirs(handwritten_dir, exist_ok=True)
    os.makedirs(printed_dir, exist_ok=True)

    # Crossed over on purpose; see the note in the docstring.
    ok_printed = cv2.imwrite(os.path.join(handwritten_dir, f"line_{cnt}.png"), printed_crop)
    ok_handwritten = cv2.imwrite(os.path.join(printed_dir, f"line_{cnt}.png"), handwritten_crop)
    if not (ok_printed and ok_handwritten):
        print(f"Warning: Failed to save one or both crops of {image_path}")
        return

    print("Image successfully split and saved.")

In [17]:
import os

folder_path = 'D://Thesis//dataset//Datasets//IAM//data//000//'
# NOTE: these two names are crossed over on purpose. The splitting function writes
# the *printed* band into its `handwritten_dir` argument and the *handwritten* band
# into `printed_dir`. Passing the printed folder as `handwritten_dir` (and vice
# versa) therefore lands each crop in the correctly named folder.
handwritten_dir = 'D://Thesis//Mass_Line_Extraction//output_images//printed'
printed_dir = 'D://Thesis//Mass_Line_Extraction//output_images//handwritten'

# Sorted so each form gets the same number on every run (os.listdir order is arbitrary).
files = sorted(os.listdir(folder_path))
cnt = 10000
for file in files:
    full_path = os.path.join(folder_path, file)
    seperate_image_section_cropping_handwritten_printed_IAM(full_path, handwritten_dir, printed_dir, cnt)
    cnt += 1
    

Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and saved.
Image successfully split and

# Training the handwritten vs. printed classifier

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt

# Select the compute device: the default CUDA GPU when PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define custom dataset
class LineDataset(Dataset):
    """Line-image dataset read from ``<root_dir>/printed`` and ``<root_dir>/handwritten``.

    Label 0 is printed text and label 1 is handwriting. Files are listed in sorted
    order, so sample indices (and therefore a seeded ``random_split``) are
    reproducible across runs and machines.

    Args:
        root_dir: Folder containing the ``printed`` and ``handwritten`` sub-folders.
        transform: Optional callable applied to each RGB PIL image.
        extensions: Lower-case file suffixes accepted as images. Other files
            (e.g. ``Thumbs.db``) are skipped instead of crashing ``__getitem__``.
            Pass ``None`` to accept every file.
    """

    # Folder order defines the label: printed -> 0, handwritten -> 1.
    CLASS_FOLDERS = ('printed', 'handwritten')

    def __init__(self, root_dir, transform=None,
                 extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        suffixes = None if extensions is None else tuple(ext.lower() for ext in extensions)
        for label, folder in enumerate(self.CLASS_FOLDERS):
            folder_path = os.path.join(root_dir, folder)
            for img_name in sorted(os.listdir(folder_path)):
                if suffixes is not None and not img_name.lower().endswith(suffixes):
                    continue
                self.samples.append((os.path.join(folder_path, img_name), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is transformed if a transform was given."""
        img_path, label = self.samples[idx]
        # The context manager closes the file; convert() returns a new, loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Image transformations: resize every line image to ResNet's 224x224 input and
# scale each channel to [-1, 1].
# NOTE(review): ImageNet-pretrained weights were trained with mean
# [0.485, 0.456, 0.406] / std [0.229, 0.224, 0.225]; consider matching that.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Create full dataset
full_dataset = LineDataset(root_dir='path_to_your_dataset_folder', transform=transform)

# Split into training and validation (80% train, 20% val). A fixed generator
# makes the split identical across kernel restarts, so validation numbers are
# comparable between runs.
SPLIT_SEED = 42
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Define the model
class SimpleCNN(nn.Module):
    """Line classifier built on an ImageNet-pretrained ResNet-18.

    Despite the name, this is not a small hand-built CNN. It loads torchvision's
    ResNet-18 with ImageNet weights and replaces the final fully connected layer
    with a fresh ``num_classes``-way linear layer (printed vs. handwritten by
    default). The backbone is stored as ``self.model``, so existing state-dict
    keys are unchanged.

    Args:
        num_classes: Number of output logits. Defaults to 2.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # torchvision >= 0.13 selects pretrained weights via `weights=`, and the old
        # `pretrained=True` flag is deprecated there. IMAGENET1K_V1 is the checkpoint
        # `pretrained=True` loaded. Fall back to the old flag on older releases.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the class logits the ResNet (with the replaced head) produces for ``x``."""
        return self.model(x)

# Seed PyTorch's global RNG before building the model, so the freshly initialised
# classification head and the training DataLoader's shuffle order are reproducible.
torch.manual_seed(42)

model = SimpleCNN().to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Per-epoch history, read by the plotting code below.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []


def _run_epoch(loader, train):
    """Run one pass over ``loader`` and return ``(mean loss per sample, accuracy %)``.

    Gradients are computed and the optimizer is stepped only when ``train`` is True.
    The loss is weighted by batch size, so a smaller final batch does not skew the
    epoch average. An empty loader yields ``(nan, nan)`` instead of raising
    ZeroDivisionError.
    """
    model.train(train)
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(train):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += batch_size

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100. * correct / total


# Training loop
epochs = 10
for epoch in range(epochs):
    train_loss, train_acc = _run_epoch(train_loader, train=True)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    val_loss, val_acc = _run_epoch(val_loader, train=False)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"Epoch [{epoch+1}/{epochs}] Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} Val Acc: {val_acc:.2f}%")

# Save the weights of the final epoch.
torch.save(model.state_dict(), 'line_classifier.pth')
print("Training complete and model saved!")

# --------------------------
# Plotting training and validation curves
# --------------------------
# Each curve is plotted against its own length rather than `epochs`, so the figure
# still renders if training stopped early (e.g. the loop was interrupted).
epochs_range = range(1, len(train_losses) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 6))

panels = (
    (loss_ax, 'Loss Curve', 'Loss',
     ((train_losses, 'Train Loss', 'red'), (val_losses, 'Val Loss', 'orange'))),
    (acc_ax, 'Accuracy Curve', 'Accuracy (%)',
     ((train_accuracies, 'Train Accuracy', 'blue'), (val_accuracies, 'Val Accuracy', 'green'))),
)
for ax, title, ylabel, curves in panels:
    for values, label, color in curves:
        ax.plot(range(1, len(values) + 1), values, marker='o', label=label, color=color)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)

fig.tight_layout()
plt.show()