<a href="https://colab.research.google.com/github/deepw98/project2/blob/main/project2_M.1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!cp -r /content/drive/MyDrive/fire_detection_few_shot /content/fire_detection_few_shot

In [3]:
import os
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import random
import torch


In [4]:
class FewShotFireDataset(Dataset):
    def __init__(self, data_dir, transform=None, augment_transform=None):
        """
        Args:
            data_dir (str): Path to the dataset folder.
            transform (torchvision.transforms): Transformations for validation/testing.
            augment_transform (torchvision.transforms): Transformations for training (augmentation).
        """
        self.data_dir = data_dir
        self.transform = transform
        self.augment_transform = augment_transform
        self.class_to_images = self._load_images_by_class()

    def _load_images_by_class(self):
        class_to_images = {}
        for class_name in os.listdir(self.data_dir):
            class_path = os.path.join(self.data_dir, class_name)
            if os.path.isdir(class_path):
                class_to_images[class_name] = [
                    os.path.join(class_path, img) for img in os.listdir(class_path)
                ]
        return class_to_images

    def __len__(self):
        return len(self.class_to_images.keys())

    def __getitem__(self, index):
        class_name = list(self.class_to_images.keys())[index]
        images = self.class_to_images[class_name]

        # Apply augmentations to the images if augmentation transforms are provided
        transformed_images = [
            self.augment_transform(Image.open(img).convert("RGB")) if self.augment_transform else self.transform(Image.open(img).convert("RGB"))
            for img in images
        ]

        return class_name, transformed_images


In [5]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match model input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # EfficientNet normalization
])

augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    # Other existing transformations
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [6]:
from torchvision.transforms import (
    RandomHorizontalFlip,
    RandomRotation,
    ColorJitter,
    RandomResizedCrop,
    GaussianBlur,
    RandomErasing,
    ToTensor,
    Normalize,
    Compose
)

augment_transform = Compose([
    RandomResizedCrop(224, scale=(0.8, 1.0)),  # Random crop with scale adjustments
    RandomHorizontalFlip(p=0.5),               # Horizontal flip
    RandomRotation(15),                        # Rotate by ±15 degrees
    ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),  # Adjust color
    GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)),  # Apply Gaussian blur
    RandomErasing(p=0.3, scale=(0.02, 0.2)),   # Random erasing
    ToTensor(),                                # Convert to tensor
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])


In [7]:
# Initialize training dataset with augmentation
train_dataset = FewShotFireDataset(
    data_dir='/content/fire_detection_few_shot/train',
    augment_transform=augment_transform  # For training
)

# Initialize testing dataset without augmentation
test_dataset = FewShotFireDataset(
    data_dir='/content/fire_detection_few_shot/test',
    transform=transform
)


In [8]:
class FewShotSampler:
    def __init__(self, dataset, n_way=2, k_shot=5, q_query=5):
        """
        Args:
            dataset (FewShotFireDataset): The dataset to sample from.
            n_way (int): Number of classes per episode.
            k_shot (int): Number of support samples per class.
            q_query (int): Number of query samples per class.
        """
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query

    def sample_episode(self):
      # Convert dict_keys to list to ensure compatibility with random.sample
      class_list = list(self.dataset.class_to_images.keys())

      # Randomly select N classes
      selected_classes = random.sample(class_list, self.n_way)

      support_images, support_labels, query_images, query_labels = [], [], [], []

      # Create support and query sets
      label_map = {class_name: i for i, class_name in enumerate(selected_classes)}
      for class_name in selected_classes:
          images = self.dataset.class_to_images[class_name]
          sampled_images = random.sample(images, self.k_shot + self.q_query)

          support_images += sampled_images[:self.k_shot]
          query_images += sampled_images[self.k_shot:]

          # Labels
          support_labels += [label_map[class_name]] * self.k_shot
          query_labels += [label_map[class_name]] * self.q_query

      return support_images, support_labels, query_images, query_labels


In [9]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match model input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # EfficientNet normalization
])


In [10]:
def collate_fn(batch):
    """
    Custom collate function to preprocess support and query sets in each batch.
    """
    support_images, support_labels, query_images, query_labels = [], [], [], []
    for support, support_lbl, query, query_lbl in batch:
        support_images += support
        support_labels += support_lbl
        query_images += query
        query_labels += query_lbl

    # Apply transforms and convert to tensors
    support_images = torch.stack([transform(Image.open(img).convert("RGB")) for img in support_images])
    query_images = torch.stack([transform(Image.open(img).convert("RGB")) for img in query_images])
    support_labels = torch.tensor(support_labels)
    query_labels = torch.tensor(query_labels)

    return support_images, support_labels, query_images, query_labels

# Initialize Dataset and Sampler
dataset = FewShotFireDataset(data_dir='/content/fire_detection_few_shot/train', transform=transform)
sampler = FewShotSampler(dataset, n_way=2, k_shot=5, q_query=5)

# Wrap the sampler into a DataLoader
def episodic_loader(sampler, num_episodes):
    for _ in range(num_episodes):
        yield sampler.sample_episode()

dataloader = DataLoader(episodic_loader(sampler, num_episodes=100), batch_size=1, collate_fn=collate_fn)


In [11]:
class EpisodicDataset:
    def __init__(self, sampler, num_episodes):
        """
        Args:
            sampler (FewShotSampler): The sampler that generates episodes.
            num_episodes (int): Number of episodes for training or testing.
        """
        self.sampler = sampler
        self.num_episodes = num_episodes

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, index):
        """
        Generates one episode using the sampler.
        """
        return self.sampler.sample_episode()


In [12]:
# Initialize Dataset and Sampler
episodic_dataset = EpisodicDataset(sampler, num_episodes=100)

# Wrap into a DataLoader
dataloader = DataLoader(episodic_dataset, batch_size=1, collate_fn=collate_fn)


In [13]:
# Initialize Dataset and Sampler
train_sampler = FewShotSampler(train_dataset, n_way=2, k_shot=5, q_query=5)
test_sampler = FewShotSampler(test_dataset, n_way=2, k_shot=5, q_query=5)

# Wrap into DataLoaders
train_loader = DataLoader(EpisodicDataset(train_sampler, num_episodes=100), batch_size=1, collate_fn=collate_fn)
test_loader = DataLoader(EpisodicDataset(test_sampler, num_episodes=10), batch_size=1, collate_fn=collate_fn)

# Training loop
for support_images, support_labels, query_images, query_labels in train_loader:
    print("Support Images Shape:", support_images.shape)
    print("Query Images Shape:", query_images.shape)
    break


Support Images Shape: torch.Size([10, 3, 224, 224])
Query Images Shape: torch.Size([10, 3, 224, 224])


In [14]:
for support_images, support_labels, query_images, query_labels in dataloader:
    print("Support Images Shape:", support_images.shape)
    print("Support Labels Shape:", support_labels.shape)
    print("Query Images Shape:", query_images.shape)
    print("Query Labels Shape:", query_labels.shape)
    break


Support Images Shape: torch.Size([10, 3, 224, 224])
Support Labels Shape: torch.Size([10])
Query Images Shape: torch.Size([10, 3, 224, 224])
Query Labels Shape: torch.Size([10])


In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_b0


In [16]:
class EmbeddingNetwork(nn.Module):
    def __init__(self):
        super(EmbeddingNetwork, self).__init__()
        # Load pre-trained EfficientNet-B0 and remove the classifier
        self.base_model = efficientnet_b0(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(self.base_model.children())[:-1])

    def forward(self, x):
        # Extract features and flatten
        x = self.feature_extractor(x).squeeze(-1).squeeze(-1)
        return x


In [17]:
import torch
import torch.nn as nn
from torchvision import models

class MatchingNetworkWithPretrainedEmbedding(nn.Module):
    def __init__(self, embedding_dim, n_way, k_shot):
        super(MatchingNetworkWithPretrainedEmbedding, self).__init__()
        # Use a pretrained ResNet backbone as the embedding network
        pretrained_resnet = models.resnet50(pretrained=True)  # Load ResNet50 with pretrained weights


        self.embedding_net = nn.Sequential(*list(pretrained_resnet.children())[:-1])  # Remove FC layer

        # Linear layer to project to the desired embedding dimension
        self.projector = nn.Linear(pretrained_resnet.fc.in_features, embedding_dim)

        # Softmax over cosine similarities
        self.softmax = nn.Softmax(dim=1)

    def forward(self, support_images, support_labels, query_images):
        # Extract embeddings for support and query images
        support_embeddings = self.embedding_net(support_images).squeeze()
        query_embeddings = self.embedding_net(query_images).squeeze()

        # Project embeddings to desired dimension
        support_embeddings = self.projector(support_embeddings)
        query_embeddings = self.projector(query_embeddings)

        # Normalize embeddings for cosine similarity
        support_embeddings = nn.functional.normalize(support_embeddings, p=2, dim=1)
        query_embeddings = nn.functional.normalize(query_embeddings, p=2, dim=1)

        # Calculate cosine similarity between support and query embeddings
        similarities = torch.matmul(query_embeddings, support_embeddings.T)

        # Softmax over similarities
        attention_weights = self.softmax(similarities)

        # Aggregate support labels weighted by attention
        support_labels_one_hot = nn.functional.one_hot(support_labels, num_classes=n_way).float()
        query_predictions = torch.matmul(attention_weights, support_labels_one_hot)

        return query_predictions


In [18]:
# # Initialize Matching Network
# embedding_dim = 1280  # EfficientNet-B0 output feature dimension
# n_way = 2
# k_shot = 5
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = MatchingNetwork(embedding_dim, n_way, k_shot).to(device)
embedding_dim = 128  # Define your desired embedding dimension
n_way = 2            # Number of classes
k_shot = 5           # Number of examples per class in support set
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MatchingNetworkWithPretrainedEmbedding(embedding_dim, n_way, k_shot).to(device)
for param in model.embedding_net.parameters():
          param.requires_grad = False

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


# Training loop
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch_idx, (support_images, support_labels, query_images, query_labels) in enumerate(dataloader):
        # Move data to GPU
        support_images, support_labels = support_images.to(device), support_labels.to(device)
        query_images, query_labels = query_images.to(device), query_labels.to(device)

        # Forward pass
        class_probabilities = model(support_images, support_labels, query_images)

        # Compute loss
        loss = criterion(class_probabilities, query_labels)
        total_loss += loss.item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(dataloader):.4f}")


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 167MB/s]


Epoch [1/20], Loss: 0.4920
Epoch [2/20], Loss: 0.4224
Epoch [3/20], Loss: 0.4057
Epoch [4/20], Loss: 0.3994
Epoch [5/20], Loss: 0.3958
Epoch [6/20], Loss: 0.3927
Epoch [7/20], Loss: 0.3943
Epoch [8/20], Loss: 0.3910
Epoch [9/20], Loss: 0.3893
Epoch [10/20], Loss: 0.3877
Epoch [11/20], Loss: 0.3906
Epoch [12/20], Loss: 0.3883
Epoch [13/20], Loss: 0.3879
Epoch [14/20], Loss: 0.3875
Epoch [15/20], Loss: 0.3898
Epoch [16/20], Loss: 0.3864
Epoch [17/20], Loss: 0.3875
Epoch [18/20], Loss: 0.3875
Epoch [19/20], Loss: 0.3881
Epoch [20/20], Loss: 0.3860


In [19]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for support_images, support_labels, query_images, query_labels in dataloader:
        # Move data to GPU
        support_images, support_labels = support_images.cuda(), support_labels.cuda()
        query_images, query_labels = query_images.cuda(), query_labels.cuda()

        # Forward pass
        class_probabilities = model(support_images, support_labels, query_images)

        # Predicted classes
        predicted_classes = torch.argmax(class_probabilities, dim=1)

        # Compute accuracy
        correct += (predicted_classes == query_labels).sum().item()
        total += query_labels.size(0)

print(f"Accuracy: {100 * correct / total:.2f}%")


Accuracy: 100.00%


In [20]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import torch

# Initialize variables to store ground truth and predictions
all_query_labels = []
all_predicted_classes = []

# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for support_images, support_labels, query_images, query_labels in dataloader:
        # Move data to GPU
        support_images, support_labels = support_images.cuda(), support_labels.cuda()
        query_images, query_labels = query_images.cuda(), query_labels.cuda()

        # Forward pass
        class_probabilities = model(support_images, support_labels, query_images)

        # Predicted classes
        predicted_classes = torch.argmax(class_probabilities, dim=1)

        # Collect predictions and labels for metric computation
        all_query_labels.extend(query_labels.cpu().numpy())
        all_predicted_classes.extend(predicted_classes.cpu().numpy())

        # Compute accuracy
        correct += (predicted_classes == query_labels).sum().item()
        total += query_labels.size(0)

# Compute accuracy
accuracy = 100 * correct / total
print(f"Accuracy: {accuracy:.2f}%")

# Compute precision, recall, F1 score, and confusion matrix
precision = precision_score(all_query_labels, all_predicted_classes, average="weighted")
recall = recall_score(all_query_labels, all_predicted_classes, average="weighted")
f1 = f1_score(all_query_labels, all_predicted_classes, average="weighted")
conf_matrix = confusion_matrix(all_query_labels, all_predicted_classes)

print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1:.2f}")
print("Confusion Matrix:")
print(conf_matrix)


Accuracy: 100.00%
Precision: 1.00
Recall: 1.00
F1 Score: 1.00
Confusion Matrix:
[[500   0]
 [  0 500]]


In [None]:
correct = 0
total = 0
episodes = 100  # Number of evaluation episodes

for _ in range(episodes):
    support_images, support_labels, query_images, query_labels = next(iter(dataloader))
    support_images, support_labels = support_images.to(device), support_labels.to(device)
    query_images, query_labels = query_images.to(device), query_labels.to(device)

    class_probabilities = model(support_images, support_labels, query_images)
    predicted_classes = torch.argmax(class_probabilities, dim=1)

    correct += (predicted_classes == query_labels).sum().item()
    total += query_labels.size(0)

accuracy = 100 * correct / total
print(f"Average Accuracy over {episodes} episodes: {accuracy:.2f}%")


Average Accuracy over 100 episodes: 99.60%


In [None]:
import os
from torchvision import transforms
from PIL import Image

def prepare_dataset(data_dir):
    """
    Prepare image data and labels from a directory.
    Args:
        data_dir (str): Path to the dataset directory.
    Returns:
        data (list): List of image file paths.
        labels (list): Corresponding class labels.
    """
    data = []
    labels = []
    class_to_idx = {}

    # Assign an integer index to each class
    for idx, class_name in enumerate(sorted(os.listdir(data_dir))):
        class_path = os.path.join(data_dir, class_name)
        if os.path.isdir(class_path):
            class_to_idx[class_name] = idx

            # Collect image file paths and labels
            for img_name in os.listdir(class_path):
                img_path = os.path.join(class_path, img_name)
                if img_name.endswith(('.png', '.jpg', '.jpeg')):  # Filter image files
                    data.append(img_path)
                    labels.append(idx)

    return data, labels, class_to_idx

# Example Usage
data_dir = "/content/fire_detection_few_shot/train"  # Path to the root dataset directory
data1, labels1, class_to_idx = prepare_dataset(data_dir)

# Output some stats
print(f"Total images: {len(data1)}")
print(f"Class-to-index mapping: {class_to_idx}")

# Example: Creating a CustomDataset instance
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])




Total images: 294
Class-to-index mapping: {'Fire': 0, 'No_Fire': 1}


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Custom Dataset for data and labels
class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        """
        Args:
            data (list): List of image file paths or tensors.
            labels (list): Corresponding labels.
            transform (callable, optional): Transform to apply to the data.
        """
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]

        # If `img` is a file path, open it as an image
        if isinstance(img, str):  # Handle file paths
            img = Image.open(img).convert("RGB")

        if self.transform:
            img = self.transform(img)

        label = self.labels[idx]
        return img, label

# Example: Assuming you have `data` and `labels`
# Example data (replace with actual

# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create Dataset and DataLoader
dataset = CustomDataset(data1, labels1, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Iterate through DataLoader
for batch_idx, (images, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx+1}:")
    print(f"Images shape: {images.shape}")
    print(f"Labels: {labels}")


Batch 1:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([0, 0, 0, 1, 1, 1, 1, 1])
Batch 2:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([0, 0, 0, 0, 1, 1, 1, 1])
Batch 3:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([1, 0, 1, 0, 1, 0, 0, 1])
Batch 4:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([1, 1, 0, 1, 0, 0, 1, 0])
Batch 5:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([0, 0, 0, 0, 0, 0, 1, 1])
Batch 6:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([1, 1, 0, 1, 1, 0, 0, 1])
Batch 7:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([0, 1, 0, 1, 0, 0, 1, 0])
Batch 8:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([0, 0, 0, 0, 1, 1, 1, 0])
Batch 9:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([1, 1, 1, 0, 1, 0, 0, 1])
Batch 10:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([1, 1, 1, 1, 0, 0, 1, 0])
Batch 11:
Images shape: torch.Size([8, 3, 224, 224])
Labels: tensor([

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50
from torch.utils.data import DataLoader, Dataset
import random

# ResNet Backbone
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        resnet = resnet50(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])  # Remove final classification layer
        self.fc = nn.Linear(2048, 256)  # Project to 256-dim space

    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # Flatten features
        x = self.fc(x)
        return x

# Triplet Loss
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_dist = torch.nn.functional.pairwise_distance(anchor, positive, p=2)
        neg_dist = torch.nn.functional.pairwise_distance(anchor, negative, p=2)
        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0.0).mean()
        return loss

# Hard Triplet Mining
def hard_triplet_mining(embeddings, labels):
    """
    Perform hard triplet mining within a batch.
    Args:
        embeddings (torch.Tensor): Embeddings of the batch (shape: [batch_size, embedding_dim]).
        labels (torch.Tensor): Corresponding labels of the batch (shape: [batch_size]).
    Returns:
        anchors, positives, negatives: Hard triplets selected from the batch.
    """
    batch_size = embeddings.size(0)
    distance_matrix = torch.cdist(embeddings, embeddings, p=2)  # Pairwise distances

    anchors, positives, negatives = [], [], []

    for i in range(batch_size):
        positive_mask = labels == labels[i]
        negative_mask = labels != labels[i]

        # Hardest positive: max distance among positives
        hardest_positive = torch.argmax(distance_matrix[i][positive_mask]).item()
        # Hardest negative: min distance among negatives
        hardest_negative = torch.argmin(distance_matrix[i][negative_mask]).item()

        anchors.append(embeddings[i])
        positives.append(embeddings[hardest_positive])
        negatives.append(embeddings[hardest_negative])

    return torch.stack(anchors), torch.stack(positives), torch.stack(negatives)

# Training Loop
def train_with_triplet_loss(model, dataloader, criterion, optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for data1, labels1 in dataloader:
            data, labels = data.cuda(), labels.cuda()

            # Forward pass
            embeddings = model(data)

            # Hard triplet mining
            anchors, positives, negatives = hard_triplet_mining(embeddings, labels)

            # Compute triplet loss
            loss = criterion(anchors, positives, negatives)
            epoch_loss += loss.item()

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# # Example Usage
# if __name__ == "__main__":
#     # Example dataset
#     data = [torch.rand(3, 224, 224) for _ in range(100)]  # Random images
#     labels = torch.tensor([random.randint(0, 4) for _ in range(100)])  # Random labels (5 classes)

    dataset = CustomDataset(data1, labels1)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    # Initialize model, loss, and optimizer
    model = FeatureExtractor().cuda()
    criterion = TripletLoss(margin=1.0)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Train the model
    train_with_triplet_loss(model, dataloader, criterion, optimizer, num_epochs=10)
