<a href="https://colab.research.google.com/github/Tom271/MLforTerrainGeneration/blob/main/MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Device configuration
# Use the CUDA device when torch reports one is available; otherwise use the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Mini-batch size shared by both loaders and by the training functions.
bs = 100

# MNIST Dataset
# Images are converted to tensors, then normalised with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)
train_dataset = datasets.MNIST('./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST('./mnist_data/', train=False, transform=transform, download=False)

# Data Loader (Input Pipeline)
# Only the training split is shuffled; the test split keeps its stored order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.7MB/s]


Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 484kB/s]


Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.44MB/s]


Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.19MB/s]

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw






In [None]:
class Generator(nn.Module):
    """Fully-connected generator: latent vector -> flattened image.

    Four linear layers (g_input_dim -> 256 -> 512 -> 1024 -> g_output_dim).
    The hidden layers use LeakyReLU with slope 0.2 and the output goes through
    tanh, so every output value lies in [-1, 1].

    Args:
        g_input_dim: size of the latent input vector.
        g_output_dim: size of the flattened output.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        # Hidden widths double at each step; layers are created in fc1..fc4
        # order so parameter names and initialisation order are unchanged.
        widths = (256, 512, 1024)
        self.fc1 = nn.Linear(g_input_dim, widths[0])
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.fc4 = nn.Linear(widths[2], g_output_dim)

    def forward(self, x):
        """Map a batch of latent vectors to a batch of flattened outputs."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden(x), negative_slope=0.2)
        return self.fc4(x).tanh()

class Discriminator(nn.Module):
    """Fully-connected discriminator: flattened image -> probability.

    Four linear layers (d_input_dim -> 1024 -> 512 -> 256 -> 1). Each hidden
    layer is followed by LeakyReLU (slope 0.2) and dropout (p=0.3); the output
    goes through a sigmoid, so it lies in [0, 1].

    Args:
        d_input_dim: size of the flattened input vector.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid scores for the input batch."""
        # F.dropout defaults to training=True, so the module's mode has to be
        # passed explicitly; otherwise dropout stays active after D.eval()
        # and inference becomes random.
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [None]:
# build network
z_dim = 100  # length of the latent noise vector fed to the generator
# Flattened image size = height * width of one stored image. `.data` is used
# instead of the deprecated `.train_data` alias, which only forwards to it
# and emits a warning.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)



In [None]:
# Bare expression: the notebook displays the generator's module repr (its layers).
G

Generator(
  (fc1): Linear(in_features=100, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)

In [None]:
# Bare expression: the notebook displays the discriminator's module repr (its layers).
D

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [None]:
# loss
# Binary cross-entropy between the discriminator's sigmoid output and the 0/1 targets.
criterion = nn.BCELoss()

# optimizer
# One Adam optimizer per network (generator first, then discriminator),
# both using the same learning rate.
lr = 2e-4
G_optimizer, D_optimizer = (
    optim.Adam(net.parameters(), lr=lr) for net in (G, D)
)

In [None]:
def D_train(x):
    """Run one discriminator update and return the discriminator loss.

    The loss is BCE on the real batch (target 1) plus BCE on an equally sized
    batch of generated samples (target 0). Only D's parameters are stepped.

    Args:
        x: batch of real images; it is flattened to (-1, mnist_dim).

    Returns:
        The summed real + fake discriminator loss as a Python float.
    """
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # train discriminator on real
    x_real = x.view(-1, mnist_dim).to(device)
    # Size the targets from the actual batch rather than the global `bs`, so
    # a short final batch does not cause a shape mismatch in the loss.
    batch_size = x_real.size(0)
    y_real = torch.ones(batch_size, 1, device=device)
    D_real_loss = criterion(D(x_real), y_real)

    # train discriminator on fake
    z = torch.randn(batch_size, z_dim, device=device)
    # G is not updated in this step, so skip building its autograd graph.
    with torch.no_grad():
        x_fake = G(z)
    y_fake = torch.zeros(batch_size, 1, device=device)
    D_fake_loss = criterion(D(x_fake), y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [None]:
def G_train(x):
    """Run one generator update and return the generator loss.

    A batch of `bs` latent vectors is sampled, passed through G and then D,
    and G is stepped so that D's output moves towards the "real" target (1).

    Args:
        x: unused; kept so existing call sites that pass the current batch
            keep working.

    Returns:
        The generator loss as a Python float.
    """
    #=======================Train the generator=======================#
    G.zero_grad()

    z = torch.randn(bs, z_dim, device=device)
    # Targets are all ones: the generator is rewarded when D scores its
    # samples as real.
    y = torch.ones(bs, 1, device=device)

    G_output = G(z)
    D_output = D(G_output)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [None]:
n_epoch = 50
for epoch in range(1, n_epoch + 1):
    D_losses, G_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        # One discriminator step followed by one generator step per batch.
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    # Epoch summary: mean of the per-batch losses collected above.
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        epoch,
        n_epoch,
        torch.FloatTensor(D_losses).mean(),
        torch.FloatTensor(G_losses).mean(),
    ))

[1/50]: loss_d: 0.881, loss_g: 3.330
[2/50]: loss_d: 1.180, loss_g: 1.399
[3/50]: loss_d: 1.121, loss_g: 1.462
[4/50]: loss_d: 0.962, loss_g: 2.014
[5/50]: loss_d: 0.746, loss_g: 1.951
[6/50]: loss_d: 0.689, loss_g: 1.927
[7/50]: loss_d: 0.676, loss_g: 2.075
[8/50]: loss_d: 0.617, loss_g: 2.160
[9/50]: loss_d: 0.667, loss_g: 2.099
[10/50]: loss_d: 0.703, loss_g: 2.064
[11/50]: loss_d: 0.730, loss_g: 1.999
[12/50]: loss_d: 0.758, loss_g: 1.950
[13/50]: loss_d: 0.789, loss_g: 1.889
[14/50]: loss_d: 0.736, loss_g: 1.996
[15/50]: loss_d: 0.730, loss_g: 2.007
[16/50]: loss_d: 0.762, loss_g: 1.895
[17/50]: loss_d: 0.776, loss_g: 1.909
[18/50]: loss_d: 0.894, loss_g: 1.641
[19/50]: loss_d: 0.865, loss_g: 1.661
[20/50]: loss_d: 0.852, loss_g: 1.720
[21/50]: loss_d: 0.862, loss_g: 1.669
[22/50]: loss_d: 0.926, loss_g: 1.549
[23/50]: loss_d: 0.935, loss_g: 1.511
[24/50]: loss_d: 0.931, loss_g: 1.515
[25/50]: loss_d: 0.974, loss_g: 1.448
[26/50]: loss_d: 0.951, loss_g: 1.465
[27/50]: loss_d: 0.97

In [None]:
import os

# Before calling save_image, create the 'samples' directory if it doesn't exist
os.makedirs('./samples', exist_ok=True)

with torch.no_grad():
    test_z = torch.randn(bs, z_dim, device=device)
    generated = G(test_z)

    # The generator ends in tanh, so its output lies in [-1, 1], while
    # save_image clamps to [0, 1] unless asked to normalise. Map back to
    # [0, 1] first so negative values are not all crushed to black.
    images = (generated.view(generated.size(0), 1, 28, 28) + 1) / 2
    save_image(images, './samples/sample_' + '.png')

In [None]:
from google.colab import files

import helper


# Write only the learned parameters (state dicts) of each network to disk.
generator_file, discriminator_file = 'generator.pt', 'discriminator.pt'
torch.save(G.state_dict(), generator_file)
torch.save(D.state_dict(), discriminator_file)

# Hand both files to Colab's download helper, generator first.
files.download(generator_file)
files.download(discriminator_file)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Load the saved weights into plain local variables. Assigning the loaded dict
# to `G.state_dict` / `D.state_dict` would replace the module's `state_dict()`
# method with an OrderedDict, so any later `G.state_dict()` call (e.g. when
# saving again) would fail.
# map_location moves tensors saved on another device onto the current one;
# weights_only=True restricts unpickling to tensors and plain containers.
G_state = torch.load('/content/drive/MyDrive/My Folder/models/generator.pth',
                     map_location=device, weights_only=True)
print(G_state.keys())

D_state = torch.load('/content/drive/MyDrive/My Folder/models/discriminator.pth',
                     map_location=device, weights_only=True)
print(D_state.keys())

G.load_state_dict(G_state)
D.load_state_dict(D_state)

  G.state_dict = torch.load('/content/drive/MyDrive/My Folder/models/generator.pth')


odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias', 'fc4.weight', 'fc4.bias'])


  D.state_dict = torch.load('/content/drive/MyDrive/My Folder/models/discriminator.pth')


odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias', 'fc4.weight', 'fc4.bias'])


<All keys matched successfully>

In [None]:
# prompt: continue training the model for another 50 epochs

n_epoch = 50
for epoch in range(1, n_epoch + 1):
    D_losses = []
    G_losses = []
    for batch_idx, (x, _) in enumerate(train_loader):
        # Discriminator update first, then generator update, for every batch.
        D_losses += [D_train(x)]
        G_losses += [G_train(x)]

    # Report the mean per-batch loss of each network for this epoch.
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        epoch,
        n_epoch,
        torch.FloatTensor(D_losses).mean(),
        torch.FloatTensor(G_losses).mean(),
    ))

[1/50]: loss_d: 1.171, loss_g: 1.065
[2/50]: loss_d: 1.156, loss_g: 1.087
[3/50]: loss_d: 1.174, loss_g: 1.063
[4/50]: loss_d: 1.184, loss_g: 1.047
[5/50]: loss_d: 1.176, loss_g: 1.054
[6/50]: loss_d: 1.178, loss_g: 1.046
[7/50]: loss_d: 1.198, loss_g: 1.020
[8/50]: loss_d: 1.195, loss_g: 1.019
[9/50]: loss_d: 1.197, loss_g: 1.016
[10/50]: loss_d: 1.197, loss_g: 1.021
[11/50]: loss_d: 1.194, loss_g: 1.022
[12/50]: loss_d: 1.211, loss_g: 0.996
[13/50]: loss_d: 1.213, loss_g: 0.992
[14/50]: loss_d: 1.214, loss_g: 0.994
[15/50]: loss_d: 1.209, loss_g: 1.002
[16/50]: loss_d: 1.217, loss_g: 0.991
[17/50]: loss_d: 1.212, loss_g: 0.986
[18/50]: loss_d: 1.215, loss_g: 0.997
[19/50]: loss_d: 1.215, loss_g: 0.996
[20/50]: loss_d: 1.212, loss_g: 0.996
[21/50]: loss_d: 1.216, loss_g: 0.982
[22/50]: loss_d: 1.217, loss_g: 0.981
[23/50]: loss_d: 1.223, loss_g: 0.978
[24/50]: loss_d: 1.224, loss_g: 0.980
[25/50]: loss_d: 1.224, loss_g: 0.970
[26/50]: loss_d: 1.233, loss_g: 0.968
[27/50]: loss_d: 1.23

In [None]:
# Summarise both checkpoints as parameter name -> shape. A bare
# `torch.load(...)` expression only displays the last one in the cell and
# dumps every weight tensor into the notebook output.
for checkpoint_path in (
    '/content/drive/MyDrive/My Folder/models/generator.pth',
    '/content/drive/MyDrive/My Folder/models/discriminator.pth',
):
    # Load onto the CPU: inspection does not need the device the file was saved from.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    print(checkpoint_path)
    for param_name, param in checkpoint.items():
        print(f'  {param_name}: {tuple(param.shape)}')


  torch.load('/content/drive/MyDrive/My Folder/models/generator.pth')
  torch.load('/content/drive/MyDrive/My Folder/models/discriminator.pth')


OrderedDict([('fc1.weight',
              tensor([[-0.0385,  0.0255,  0.0014,  ...,  0.0012,  0.0294, -0.0178],
                      [-0.0200,  0.0281, -0.0315,  ...,  0.0066, -0.0217, -0.0235],
                      [ 0.0071,  0.0094, -0.0012,  ...,  0.0033,  0.0066, -0.0258],
                      ...,
                      [-0.0014, -0.0232,  0.0190,  ..., -0.0072,  0.0733, -0.0065],
                      [ 0.0049,  0.0268, -0.0631,  ...,  0.0519,  0.0049,  0.0247],
                      [-0.0159, -0.0110, -0.0161,  ..., -0.0119,  0.0458, -0.0378]],
                     device='cuda:0')),
             ('fc1.bias',
              tensor([ 0.2752,  0.4087,  0.2221,  ...,  0.3341,  0.0101, -0.2173],
                     device='cuda:0')),
             ('fc2.weight',
              tensor([[ 0.0161, -0.0546,  0.0003,  ..., -0.0900, -0.0618,  0.1365],
                      [ 0.0051,  0.0046,  0.0230,  ...,  0.0585,  0.0371,  0.0512],
                      [ 0.0387,  0.0816, -0.0331,  ...,