In [2]:
import torch
import random 
from tqdm import tqdm  # A library for displaying progress Bars
from torch.utils.data import Dataset
import torch.nn as nn
import math
from PIL import Image  #Pytorch Image Library
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import numpy as np
from torch.utils.data import DataLoader
import sys
from torch.utils.tensorboard import SummaryWriter
import argparse
import subprocess
from datetime import datetime
from shutil import copyfile
from stat import S_IREAD, S_IRGRP, S_IROTH


# Default torch device for code that uses `device` (instead of a hard-coded .cuda()).
# Falls back to the CPU when CUDA is unavailable so those paths still run.
device = "cuda" if torch.cuda.is_available() else "cpu"

def _m(a,max_depth):  # This function calculates the Mandelbrot value for a given complex number a to a max depth of 'max_depth'
    z = 0
    for n in range(max_depth):
        z = z**2 + a
        if abs(z)>2:
            return smoothMandelbrot(n)
        
    return 1.0

def smoothMandelbrot(iters, smoothness=50):
    """Map an escape iteration count to a smooth value in [0, 1).

    Computes 1 - 1 / (iters / smoothness + 1): 0 for iters == 0, 0.5 when
    iters == smoothness, and approaching 1 as iters grows. A larger `smoothness`
    makes the value rise more slowly.
    """
    scaled = iters / smoothness + 1
    return 1 - 1 / scaled

def mandelbrot(x, y, max_depth=50):
    """Smoothed Mandelbrot escape value of the point x + iy.

    Not a boolean membership test: returns 1.0 for points that stay bounded for
    `max_depth` iterations and smoothMandelbrot(n) (< 1) for points escaping at step n.
    """
    point = x + 1j * y
    return _m(point, max_depth)

def mandelbrotGPU(resx, resy, xmin, xmax, ymin, ymax, max_depth):
    """Smoothed Mandelbrot values on a resy x resx grid, computed on `device`.

    Columns are evenly spaced from xmin to xmax and rows from ymin (row 0) to ymax.
    Both endpoints are included (torch.linspace).

    Returns:
        float64 tensor of shape (resy, resx) on `device` (see mandelbrotTensor).
    """
    X = torch.linspace(xmin, xmax, resx, device=device, dtype=torch.float64)
    Y = torch.linspace(ymin, ymax, resy, device=device, dtype=torch.float64)

    # 'ij' indexing: the first grid varies along rows (Y), the second along columns (X).
    # This is torch's historical default; passing it explicitly silences the warning
    # newer releases emit when `indexing` is omitted.
    imag_values, real_values = torch.meshgrid(Y, X, indexing="ij")
    return mandelbrotTensor(imag_values, real_values, max_depth)


def mandelbrotTensor(imag_values, real_values, max_depth):
    """Vectorised smoothed Mandelbrot values for a tensor of points.

    Mirrors the scalar `_m`. A point whose orbit first satisfies |z| > 2 at iteration n
    gets smoothMandelbrot(n). Points that stay bounded for all `max_depth` iterations
    get 1.0.

    Args:
        imag_values: tensor of imaginary parts.
        real_values: tensor of real parts (same shape as, or broadcastable with,
            imag_values).
        max_depth: number of iterations.

    Returns:
        float64 tensor with the shape of ``real_values + 1j * imag_values``, on the same
        device as the inputs.
    """
    # Combine real and imaginary parts into a complex tensor.
    c = real_values + 1j * imag_values
    # float64 start value: z**2 + c is then promoted to complex128 even for float32 inputs.
    z = torch.zeros_like(c, dtype=torch.float64)
    bounded = torch.ones_like(c, dtype=torch.bool)
    final_image = torch.zeros_like(c, dtype=torch.float64)

    for n in range(max_depth):
        z = z**2 + c
        # Points escaping *at this step* receive this step's value, exactly like `_m`.
        escaped_now = bounded & (torch.abs(z) > 2)
        final_image[escaped_now] = smoothMandelbrot(n)
        bounded &= ~escaped_now

    # Points that never escaped are full white. Use the mask rather than |z| <= 2,
    # because escaped orbits keep iterating and may overflow to inf/nan.
    final_image[bounded] = 1.0
    return final_image

# Dataset of randomly sampled points of the complex plane and their smoothed Mandelbrot values.
class MandelbrotDataSet(Dataset):
    """Random (x, y) points within [xmin, xmax] x [ymin, ymax] with Mandelbrot targets.

    Each item is ``(input, output, index)``:
        input: ``[x, y]`` tensor.
        output: scalar float64 value from ``mandelbrot`` / ``mandelbrotTensor``.
        index: the row of ``self.inputs`` the item came from.

    Indices ``>= len(self.inputs)`` address the oversample list. ``train`` refills it each
    epoch with high-loss samples via ``add_oversample`` and ``update_oversample``.

    Args:
        size: number of points to generate (ignored when ``loadfile`` is given).
        loadfile: name previously passed to ``save``; loads ./data/<name>_*.pt instead.
        max_depth: iteration limit of the Mandelbrot evaluation.
        xmin, xmax, ymin, ymax: sampling rectangle.
        dtype: dtype of the sampled coordinates.
        gpu: sample and score in one vectorised pass on ``device`` instead of a Python loop.
    """

    def __init__(self, size=1000, loadfile=None, max_depth=50, xmin=-2.5, xmax=1.0, ymin=-1.1, ymax=1.1, dtype=torch.float32, gpu=False):
        self.inputs = []
        self.outputs = []
        if loadfile is not None:
            self.load(loadfile)
        else:
            print("Generating Dataset")
            if not gpu:
                self._generate_cpu(size, max_depth, xmin, xmax, ymin, ymax, dtype)
            else:
                self._generate_gpu(size, max_depth, xmin, xmax, ymin, ymax, dtype)
        self.start_oversample(len(self.inputs))

    def _generate_cpu(self, size, max_depth, xmin, xmax, ymin, ymax, dtype):
        """Draw points with `random.uniform` and score them one at a time with `mandelbrot`."""
        inputs, outputs = [], []
        for _ in tqdm(range(size)):
            x = random.uniform(xmin, xmax)
            y = random.uniform(ymin, ymax)
            inputs.append(torch.tensor([x, y], dtype=dtype))
            # Scalar float64 target so shapes/dtypes match the GPU path, which yields (size,).
            outputs.append(torch.tensor(mandelbrot(x, y, max_depth), dtype=torch.float64))
        if not inputs:
            # torch.stack cannot stack an empty list.
            self.inputs = torch.empty((0, 2), dtype=dtype)
            self.outputs = torch.empty((0,), dtype=torch.float64)
            return
        self.inputs = torch.stack(inputs)
        self.outputs = torch.stack(outputs)

    def _generate_gpu(self, size, max_depth, xmin, xmax, ymin, ymax, dtype):
        """Draw all points at once on `device`, score them with `mandelbrotTensor`, keep results on the CPU."""
        # (lo - hi) * U[0, 1) + hi samples the interval (lo, hi].
        X = (xmin - xmax) * torch.rand((size), dtype=dtype, device=device) + xmax
        Y = (ymin - ymax) * torch.rand((size), dtype=dtype, device=device) + ymax
        self.inputs = torch.stack([X, Y], dim=1).cpu()
        self.outputs = mandelbrotTensor(Y, X, max_depth).cpu()

    def __getitem__(self, i):
        n_base = len(self.inputs)
        if i >= n_base:
            # Oversample slot: look up which base sample it repeats.
            ind = self.oversample_indices[i - n_base].item()
            return self.inputs[ind], self.outputs[ind], ind
        return self.inputs[i], self.outputs[i], i

    def __len__(self):
        return len(self.inputs) + len(self.oversample_indices)

    def start_oversample(self, max_size):
        """Reset oversampling; at most `max_size` extra items are exposed per epoch."""
        self.max_size = max_size
        self.oversample_indices = torch.tensor([], dtype=torch.long)
        self.oversample_buffer = torch.tensor([], dtype=torch.long)

    def update_oversample(self):
        """Expose the first `max_size` buffered indices as the oversample list and clear the buffer."""
        self.oversample_indices = self.oversample_buffer[:self.max_size]
        self.oversample_buffer = torch.tensor([], dtype=torch.long)

    def add_oversample(self, indices):
        """Queue base-sample indices to be repeated after the next `update_oversample`."""
        # Keep only indices that point at base samples (this does not deduplicate).
        indices = indices[indices < len(self.inputs)]
        self.oversample_buffer = torch.cat([self.oversample_buffer, indices], 0)

    def save(self, filename):
        """Write inputs/outputs to ./data/<filename>_inputs.pt and ./data/<filename>_outputs.pt."""
        os.makedirs("./data", exist_ok=True)
        torch.save(self.inputs, './data/'+filename+'_inputs.pt')
        torch.save(self.outputs, './data/'+filename+'_outputs.pt')

    def load(self, filename):
        """Read the tensors written by `save`.

        torch.load unpickles its input, so only load files you trust.
        """
        self.inputs = torch.load('./data/'+filename+'_inputs.pt')
        self.outputs = torch.load('./data/'+filename+'_outputs.pt')
        


In [3]:
class ImageDataset(Dataset):
    """Per-pixel regression dataset built from one grayscale image.

    Item ``idx`` is the pixel at row ``idx // width`` and column ``idx % width``
    (row-major). Its input is the pixel's coordinates rescaled so column 0 maps to -1
    and row 0 maps to +1. Its output is the grayscale intensity in [0, 1].
    """

    def __init__(self, image_path):
        # Load, convert to 8-bit grayscale ('L'), and let ToTensor scale pixels to [0, 1].
        grayscale = Image.open(image_path).convert('L')
        self.image = ToTensor()(grayscale)
        # Tensor is (channels, height, width); keep the spatial dimensions.
        self.height, self.width = self.image.shape[1:]

    def __len__(self):
        # One item per pixel.
        return self.width * self.height

    def __getitem__(self, idx):
        # Flat index -> (row, col), row-major.
        row = idx // self.width
        col = idx % self.width

        # Rescale: x grows to the right from -1, y shrinks downward from +1.
        half_width = self.width / 2
        half_height = self.height / 2
        coords = torch.tensor([col / half_width - 1, (self.height - row) / half_height - 1])

        intensity = self.image[0, row, col]
        return coords, intensity

    def display_image(self):
        """Debug helper: rebuild the image through __getitem__ and show it."""
        canvas = torch.zeros((self.height, self.width))
        for flat_idx in range(len(self)):
            r = flat_idx // self.width
            c = flat_idx % self.width
            canvas[r, c] = self[flat_idx][1]
        plt.imshow(canvas, cmap='gray')
        plt.show()

In [4]:
class Simple(nn.Module):
    """Plain MLP with a (0, 1)-squashed scalar output.

    Structure: Linear(init_size, hidden) + ReLU, then `num_hidden_layers` blocks of
    Linear(hidden, hidden) + ReLU, then Linear(hidden, 1). The output is mapped through
    (tanh(y) + 1) / 2.
    """
    def __init__(self, hidden_size=100, num_hidden_layers = 7, init_size=2):
        super(Simple, self).__init__()
        # One flat list of modules (input layer and its activation).
        layers = [nn.Linear(init_size, hidden_size), nn.ReLU()]
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_size, 1))

        self.tanh = nn.Tanh()
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        # Rescale tanh's (-1, 1) range into (0, 1).
        return (self.tanh(self.seq(x))+1)/2


class SkipConn(nn.Module):
    """MLP whose hidden layers see [current activation, previous activation, raw input].

    The first hidden layer has no previous activation, so it sees only
    [current, input]. The output is mapped to (0, 1) via (tanh(y) + 1) / 2.

    Args:
        hidden_size: width of every hidden activation.
        num_hidden_layers: number of skip-connected hidden layers (may be 0).
        init_size: width of the input.
        linmap: optional object with a ``map(x)`` method applied to the input first
            (e.g. CenteredLinearMap).
    """
    def __init__(self, hidden_size=100, num_hidden_layers=7, init_size=2, linmap=None):
        super(SkipConn, self).__init__()
        out_size = hidden_size

        self.inLayer = nn.Linear(init_size, out_size)
        self.relu = nn.LeakyReLU()
        hidden = []
        for i in range(num_hidden_layers):
            # The first hidden layer has no previous activation to concatenate.
            in_size = out_size*2 + init_size if i>0 else out_size + init_size
            hidden.append(nn.Linear(in_size, out_size))
        self.hidden = nn.ModuleList(hidden)
        # With no hidden layers, `prev` is still empty when the head runs.
        head_in = out_size*2 + init_size if num_hidden_layers > 0 else out_size + init_size
        self.outLayer = nn.Linear(head_in, 1)
        self.tanh = nn.Tanh()
        self.sig = nn.Sigmoid()  # unused; kept so the module structure stays the same
        self._linmap = linmap

    def forward(self, x):
        if self._linmap:
            x = self._linmap.map(x)
        cur = self.relu(self.inLayer(x))
        # Empty (batch, 0) placeholder created on x's device/dtype, so the first
        # concatenation works on any device.
        prev = x.new_empty((x.shape[0], 0))
        for layer in self.hidden:
            combined = torch.cat([cur, prev, x], 1)
            prev = cur
            cur = self.relu(layer(combined))
        y = self.outLayer(torch.cat([cur, prev, x], 1))
        return (self.tanh(y)+1)/2

class Fourier(nn.Module):
    """SkipConn applied to sinusoidal features of the (optionally linearly mapped) input.

    For each input coordinate the features are sin(k*x) for k = 1..fourier_order, then
    cos(k*x) for the same k, then the raw coordinate. For 2-D inputs this gives
    fourier_order * 4 + 2 features, the inner model's input width.
    """
    def __init__(self, fourier_order=4, hidden_size=100, num_hidden_layers=7, linmap=None):
        super(Fourier, self).__init__()
        self.fourier_order = fourier_order
        self.fourier_ordr = fourier_order  # legacy misspelled attribute, kept for existing callers
        self.inner_model = SkipConn(hidden_size, num_hidden_layers, fourier_order*4+2)
        self._linmap = linmap
        # Buffer rather than a plain attribute, so it follows .to()/.cuda()/.cpu() with the
        # model. Non-persistent, so state_dicts saved before this change still load.
        self.register_buffer('orders', torch.arange(1, fourier_order + 1).float(), persistent=False)

    def forward(self, x):
        if self._linmap:
            x = self._linmap.map(x)
        x = x.unsqueeze(-1)          # assumes x is (batch, 2) -> (batch, 2, 1)
        angles = self.orders * x     # (batch, 2, fourier_order)
        fourier_features = torch.cat([torch.sin(angles), torch.cos(angles), x], dim=-1)
        fourier_features = fourier_features.reshape(x.shape[0], -1)
        return self.inner_model(fourier_features)
    
class Fourier2D(nn.Module):
    """SkipConn applied to 2-D Fourier product features of the input.

    Features are the raw (x0, x1) followed by, for n (outer) and m (inner) in
    0..fourier_order-1:
    [cos(n x0) cos(m x1), cos(n x0) sin(m x1), sin(n x0) cos(m x1), sin(n x0) sin(m x1)].
    That is 4 * fourier_order**2 + 2 columns in total.
    """
    def __init__(self, fourier_order=4, hidden_size=100, num_hidden_layers=7, linmap = None):
        super(Fourier2D, self).__init__()
        self.fourier_order = fourier_order
        self.inner_model = SkipConn(hidden_size, num_hidden_layers, (fourier_order * fourier_order * 4) + 2)
        self._linmap = linmap
        # Device-following, non-persistent buffer (keeps state_dict keys unchanged).
        self.register_buffer('orders', torch.arange(0, fourier_order).float(), persistent=False)

    def forward(self, x):
        if self._linmap:
            x = self._linmap.map(x)
        # Per-axis harmonics, each (batch, K).
        nx = x[:, 0:1] * self.orders
        my = x[:, 1:2] * self.orders
        cos_x, sin_x = torch.cos(nx), torch.sin(nx)
        cos_y, sin_y = torch.cos(my), torch.sin(my)
        # Outer products over (n, m) -> (batch, K, K). Stacking on the last dim gives
        # (batch, K, K, 4), whose row-major flattening is the order
        # "for n: for m: [cc, cs, sc, ss]".
        products = torch.stack([
            cos_x[:, :, None] * cos_y[:, None, :],
            cos_x[:, :, None] * sin_y[:, None, :],
            sin_x[:, :, None] * cos_y[:, None, :],
            sin_x[:, :, None] * sin_y[:, None, :],
        ], dim=-1)
        # Concatenate along the feature dimension (dim 1), not the batch dimension.
        fourier_features = torch.cat([x, products.reshape(x.shape[0], -1)], dim=1)
        return self.inner_model(fourier_features)
        
    
class CenteredLinearMap():
    """Per-axis affine map p -> m * p + b for (x, y) points.

    If x_size (y_size) is given, the span [xmin, xmax] ([ymin, ymax]) is scaled to have
    that width; otherwise the scale is 1. The offset sends the centre of the y span to 0
    and the centre of the x span to -1.
    NOTE(review): the extra -1 shift on x looks deliberate but is undocumented; confirm.
    """
    def __init__(self, xmin=-2.5, xmax=1.0, ymin=-1.1, ymax=1.1, x_size=None, y_size=None):
        x_m = x_size/(xmax - xmin) if x_size is not None else 1.
        y_m = y_size/(ymax - ymin) if y_size is not None else 1.
        x_b = -(xmin + xmax)*x_m/2 - 1
        y_b = -(ymin + ymax)*y_m/2
        self.m = torch.tensor([x_m, y_m], dtype = torch.float)
        self.b = torch.tensor([x_b, y_b], dtype = torch.float)

    def map(self, x):
        """Apply the map to `x` (assumed (..., 2)) on x's own device."""
        return self.m.to(x.device)*x + self.b.to(x.device)

class Taylor(nn.Module):
    """SkipConn applied to polynomial features of the (optionally mapped) input.

    Features are x followed by x**1 .. x**taylor_order, concatenated along dim 1 (so x
    appears twice). For 2-D x that is taylor_order * 2 + 2 columns.
    """
    def __init__(self, taylor_order=4, hidden_size=100, num_hidden_layers = 7, linmap = None):
        super(Taylor, self).__init__()
        self.taylor_order = taylor_order
        self._linmap = linmap
        self.inner_model = SkipConn(hidden_size, num_hidden_layers, taylor_order*2 + 2)

    def forward(self, x):
        if self._linmap:
            x = self._linmap.map(x)
        powers = [x] + [x**n for n in range(1, self.taylor_order + 1)]
        return self.inner_model(torch.cat(powers, 1))


In [5]:
# Output directory for rendered still images (example_render / example_render_model save here).
os.makedirs("./captures/images", exist_ok=True)

def renderMandelbrot(resx, resy, xmin=-2.4, xmax=1, yoffset=0, max_depth=50, gpu=False):
    """Render a (resy, resx) image of smoothed Mandelbrot values.

    The pixel step is (xmax - xmin) / resx on both axes. The view spans
    -yoffset +/- step * resy / 2 vertically.

    The CPU path steps through the half-open ranges with np.arange. The GPU path uses
    torch.linspace, which includes both ends, so the two grids differ slightly.

    Args:
        resx, resy: width and height in pixels.
        xmin, xmax: horizontal extent of the view.
        yoffset: vertical shift of the view.
        max_depth: iteration limit per point.
        gpu: use mandelbrotGPU instead of the per-pixel Python loop.

    Returns:
        numpy array of shape (resy, resx) with values in [0, 1].
    """
    step = (xmax - xmin) / resx
    half_span = step * resy / 2
    ymin = -half_span - yoffset
    ymax = half_span - yoffset
    if gpu:
        return mandelbrotGPU(resx, resy, xmin, xmax, ymin, ymax, max_depth).cpu().numpy()

    # Float arange can overshoot by one element; trim to the requested resolution.
    xs = np.arange(xmin, xmax, step)[:resx]
    ys = np.arange(ymin, ymax, step)[:resy]
    image = np.zeros((resy, resx))
    for col, x in enumerate(tqdm(xs)):
        for row, y in enumerate(ys):
            image[row, col] = mandelbrot(x, y, max_depth)
    return image
    
def renderModel(model, resx, resy, xmin=-2.4, xmax=1, yoffset=0, linspace=None, max_gpu=False):
    """Evaluate `model` over a 2-D grid of points and return the image as numpy.

    Args:
        model: nn.Module mapping (N, 2) points to (N, 1) values.
        resx, resy: image width and height.
        xmin, xmax, yoffset: viewport used to build the grid when `linspace` is None
            (see generateLinspace).
        linspace: optional precomputed grid, either (resy, resx, 2) or (resx*resy, 2).
        max_gpu: if True, evaluate all points in one forward pass; otherwise evaluate one
            grid row per pass (smaller batches).

    Returns:
        numpy array, normally (resy, resx), clamped to [0, 1].

    The model is put in eval mode while rendering, and its previous train/eval mode is
    restored afterwards (even if rendering raises).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            if linspace is None:
                linspace = generateLinspace(resx, resy, xmin, xmax, yoffset)

            if not max_gpu:
                im = torch.stack([model(points) for points in linspace], 0)
            else:
                if linspace.shape != (resx*resy, 2):
                    linspace = torch.reshape(linspace, (resx*resy, 2))
                im = model(linspace).squeeze()
                im = torch.reshape(im, (resy, resx))

            im = torch.clamp(im, 0, 1)
    finally:
        torch.cuda.empty_cache()
        model.train(was_training)
    return im.squeeze().cpu().numpy()
    
def generateLinspace(resx, resy, xmin=-2.4, xmax=1, yoffset=0, device="cuda"):
    """Build the (resy, resx, 2) grid of (x, y) points used by renderModel.

    Both axes use the step (xmax - xmin) / resx. Columns start at xmin. Rows cover
    -yoffset +/- step * resy / 2, increasing with the row index.

    Args:
        resx, resy: grid width and height.
        xmin, xmax: horizontal extent.
        yoffset: vertical shift of the view.
        device: device the grid is placed on (default "cuda", as before).

    Returns:
        float tensor of shape (resy, resx, 2) where [r, c] = (x_c, y_r).
    """
    iteration = (xmax - xmin) / resx
    # Build the ranges on the CPU (as before) and move them, so the values are unchanged.
    X = torch.arange(xmin, xmax, iteration)[:resx].to(device)
    y_max = iteration * resy / 2
    Y = torch.arange(-y_max - yoffset, y_max - yoffset, iteration)[:resy].to(device)
    # Broadcast instead of building one row per Python iteration.
    grid_x = X.unsqueeze(0).expand(len(Y), -1)
    grid_y = Y.unsqueeze(1).expand(-1, len(X))
    return torch.stack([grid_x, grid_y], dim=-1)

class VideoMaker:
    """Render training snapshots of a model to numbered PNGs and encode them with ffmpeg.

    Frames are written to ./frames/<name>/NNNNN.png, and the video to
    ./frames/<name>/<name>.mp4.

    `shots` is an optional list of dicts with keys 'frame', 'xmin', 'xmax', 'yoffset'
    and optionally 'capture_rate'. When the frame counter reaches a shot's 'frame', the
    viewport (and capture rate) switch to that shot's values.
    """
    def __init__(self, name='autosave', fps=30, dims=(100,100), capture_rate=10, shots=None,
                 max_gpu=False, cmap='magma'):
        self.name = name
        self.dims = dims
        self.capture_rate = capture_rate
        self.max_gpu = max_gpu
        self._xmin = -2.4
        self._xmax = 1
        self._yoffset = 0
        # Copy so that popping consumed shots does not mutate the caller's list.
        self.shots = list(shots) if shots is not None else None
        self.cmap = cmap
        self.fps = fps
        os.makedirs(f'./frames/{self.name}', exist_ok = True)

        self.linspace = generateLinspace(self.dims[0], self.dims[1], self._xmin, self._xmax, self._yoffset)
        if max_gpu:
            self.linspace = torch.reshape(self.linspace, (dims[0]*dims[1], 2))

        self.frame_count = 0

    def generateFrame(self, model):
        """Render `model` into the next numbered frame, first applying any shot that is due."""
        if self.shots and self.frame_count >= self.shots[0]['frame']:
            shot = self.shots.pop(0)
            self._xmin = shot["xmin"]
            self._xmax = shot["xmax"]
            self._yoffset = shot["yoffset"]
            if "capture_rate" in shot:
                self.capture_rate = shot["capture_rate"]
            self.linspace = generateLinspace(self.dims[0], self.dims[1], self._xmin, self._xmax, self._yoffset)

        im = renderModel(model, self.dims[0], self.dims[1], linspace=self.linspace, max_gpu = self.max_gpu)
        plt.imsave(f'./frames/{self.name}/{self.frame_count:05d}.png', im, cmap = self.cmap)
        # Advance so the next frame gets a new file name and shot triggers can fire.
        self.frame_count += 1

    def generateVideo(self):
        """Encode the saved frames into ./frames/<name>/<name>.mp4 (best effort: failures are reported, not raised)."""
        frame_dir = f'./frames/{self.name}'
        # Argument list (no shell), so names with spaces or shell metacharacters are safe.
        command = [
            'ffmpeg', '-y', '-r', str(self.fps),
            '-i', f'{frame_dir}/%05d.png',
            '-c:v', 'libx264', '-preset', 'veryslow', '-crf', '0', '-pix_fmt', 'yuv420p',
            f'{frame_dir}/{self.name}.mp4',
        ]
        try:
            result = subprocess.run(command)
        except FileNotFoundError:
            print("ffmpeg not found; skipping video generation.")
            return
        if result.returncode != 0:
            print(f"ffmpeg exited with code {result.returncode}; the video may be missing.")

In [6]:
class Logger:
    """Tee stdout into a per-run results file.

    Ensures the run directory exists, stores a read-only copy of `file` there, then
    replaces sys.stdout with this object. Every print goes both to the original console
    and to <dir>/results.txt.

    Args:
        file: path of the source file to snapshot into the run directory.
        dir: run directory; defaults to ./results/<timestamp>.
    """
    def __init__(self, file, dir=None):
        if dir is None:
            self.dir = './results/'+datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        else:
            # join() keeps absolute paths absolute, unlike './' + dir.
            self.dir = os.path.join('.', dir)
        # Create the directory in both cases before copying into it.
        os.makedirs(self.dir, exist_ok=True)

        self.copyFile(file)
        self.console = sys.stdout
        self.results = open(self.dir + "/results.txt", 'a+')
        sys.stdout = self
        print('LOGGING RUN RESULTS')

    def copyFile(self, file):
        """Copy `file` into the run directory and make the copy read-only."""
        import stat
        copy_file = os.path.join(self.dir, os.path.basename(file))
        if os.path.exists(copy_file):
            # A previous copy was made read-only; allow it to be overwritten.
            os.chmod(copy_file, stat.S_IRUSR | stat.S_IWUSR)
        copyfile(file, copy_file)
        os.chmod(copy_file, S_IREAD|S_IRGRP|S_IROTH)

    def createDir(self, dir_name):
        """Create a sub-directory of the run directory (no error if it already exists)."""
        os.makedirs(os.path.join(self.dir, dir_name), exist_ok=True)

    def write(self, message):
        self.console.write(message)
        self.results.write(message)

    def flush(self):
        self.console.flush()
        self.results.flush()

    def close(self):
        """Restore the original stdout (if it still points here) and close the results file."""
        if sys.stdout is self:
            sys.stdout = self.console
        self.results.close()

    def __del__(self):
        # __init__ may have failed before these attributes were set.
        console = getattr(self, 'console', None)
        if console is not None and sys.stdout is self:
            sys.stdout = console
        results = getattr(self, 'results', None)
        if results is not None:
            results.close()
            


In [7]:
# Directory where train() writes checkpoints ('./models/' + savemodelas).
# NOTE(review): this cell is repeated verbatim in the next cell (also "In [7]"); the later
# train/evaluate definitions shadow these, so keep the two in sync or delete one copy.
os.makedirs("./models", exist_ok = True)

def train(model, dataset, epochs, batch_size=1000, use_scheduler=False, oversample=0, eval_dataset=None, savemodelas='autosave.pt', snapshots_every=-1, vm=None):
    """Train `model` with Adam on mean absolute error, logging progress to TensorBoard.

    Args:
        model: nn.Module mapping (B, 2) points to (B, 1) predictions. Batches are moved to
            the device of its first parameter.
        dataset: map-style dataset yielding (input, target, index). When `oversample` is
            non-zero it must also provide start_oversample/add_oversample (as
            MandelbrotDataSet does).
        epochs: number of passes over `dataset`.
        batch_size: batch size for training and for `evaluate`.
        use_scheduler: halve the learning rate every 5 epochs (StepLR).
        oversample: fraction of each batch (its highest-loss samples) queued to be repeated
            in the next epoch; 0 disables oversampling.
        eval_dataset: optional dataset scored with `evaluate` at the start, at every
            snapshot, and at the end.
        savemodelas: file name under ./models for the state_dict (saved every epoch and
            at the end); None disables saving.
        snapshots_every: iterations between TensorBoard image/eval snapshots; -1 disables.
        vm: optional VideoMaker that captures a frame every `vm.capture_rate` iterations.

    Returns:
        List with the average training loss of each epoch.
    """
    print("Initializing...")
    tb = SummaryWriter()
    # `__file__` is not defined inside a notebook kernel; fall back to the launcher script.
    source_file = globals().get("__file__", sys.argv[0])
    logger = Logger(source_file, dir=tb.log_dir)
    # Don't copy the same file twice (the first copy is read-only).
    if os.path.abspath(sys.argv[0]) != os.path.abspath(source_file):
        logger.copyFile(sys.argv[0])

    logger.createDir('images')
    logger.createDir('models')

    model_device = next(model.parameters()).device
    optim = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    scheduler = None
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=5, gamma=0.5)

    per_batch = 0
    if oversample != 0:
        per_batch = math.floor(batch_size * oversample)
        dataset.start_oversample(math.floor(len(dataset)*oversample))

    print('Training...')
    avg_losses = []
    tot_iterations = 0
    if eval_dataset is not None:
        tb.add_scalar('Loss/eval', evaluate(model, eval_dataset, batch_size), tot_iterations)

    for epoch in range(epochs):
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loop = tqdm(total=len(loader), position=0)
        tot_loss = 0

        for i, (inputs, outputs, indices) in enumerate(loader):
            if vm is not None and tot_iterations % vm.capture_rate == 0:
                vm.generateFrame(model)
            inputs, outputs = inputs.to(model_device), outputs.to(model_device)

            optim.zero_grad()

            # Flatten both to (B,) so a (B, 1) target can never broadcast to (B, B).
            pred = model(inputs).reshape(-1).float()
            outputs = outputs.reshape(-1).float()

            all_losses = torch.abs(outputs - pred)

            if oversample != 0:
                # Queue this batch's highest-loss samples to be repeated next epoch.
                size = min(per_batch, len(all_losses))
                selected_indices = torch.topk(all_losses, size)[1].cpu()
                dataset.add_oversample(indices[selected_indices])

            loss = torch.mean(all_losses)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
            optim.step()
            tot_loss += loss.item()

            loop.set_description('epoch:{:d} Loss:{:.6f}'.format(epoch, tot_loss/(i+1)))
            loop.update(1)
            tb.add_scalar('Loss/Train', loss.item(), tot_iterations)
            # Read the LR from the optimizer so this also works without a scheduler.
            tb.add_scalar('Learning Rate', optim.param_groups[0]['lr'], tot_iterations)
            tot_iterations += 1

            if snapshots_every != -1 and tot_iterations % snapshots_every == 0:
                # renderModel returns a 2-D (height, width) numpy array.
                sample = renderModel(model, 960, 544, max_gpu=True)
                tb.add_image('sample', sample, tot_iterations, dataformats='HW')
                if eval_dataset is not None:
                    tb.add_scalar('Loss/eval', evaluate(model, eval_dataset, batch_size), tot_iterations)

        loop.close()
        avg_losses.append(tot_loss/len(loader) if len(loader) else float('nan'))

        if scheduler is not None:
            scheduler.step()
        if hasattr(dataset, 'update_oversample'):
            dataset.update_oversample()

        if savemodelas is not None:
            torch.save(model.state_dict(), './models/'+savemodelas)
    print("Finished Training.")
    print("Final learning rate: ", optim.param_groups[0]['lr'])
    if eval_dataset is not None:
        tb.add_scalar('Loss/eval', evaluate(model, eval_dataset, batch_size), tot_iterations)

    if vm is not None:
        print("Finalizing Capture...")
        vm.generateFrame(model)
        vm.generateVideo()
    if savemodelas is not None:
        print('Saving...')
        torch.save(model.state_dict(), './models/'+savemodelas)
    print("Done.")
    plt.show()
    tb.close()
    # Stop teeing stdout into this run's results file.
    if sys.stdout is logger:
        sys.stdout = logger.console
    return avg_losses

def evaluate(model, eval_dataset, batch_size):
    """Mean absolute error of `model` on `eval_dataset`, averaged over batches.

    Batches are (inputs, targets, indices) and are moved to the device of the model's
    first parameter. The model's train/eval mode is restored afterwards.

    Returns:
        The average of the per-batch MAE values, or nan for an empty dataset.
    """
    was_training = model.training
    model.eval()
    try:
        model_device = next(model.parameters(), torch.zeros(0)).device
        loader = DataLoader(eval_dataset, batch_size=batch_size)
        tot_loss = 0.0
        with torch.no_grad():
            for inputs, outputs, _ in loader:
                inputs, outputs = inputs.to(model_device), outputs.to(model_device)
                pred = model(inputs).reshape(-1).float()
                loss = torch.mean(torch.abs(outputs.reshape(-1).float() - pred))
                # Accumulate across batches (not just keep the last one).
                tot_loss += loss.item()
    finally:
        model.train(was_training)
    if len(loader) == 0:
        return float('nan')
    return tot_loss/len(loader)

In [7]:
# Directory where train() writes checkpoints ('./models/' + savemodelas).
# NOTE(review): this cell duplicates the previous one; these train/evaluate definitions
# are the ones that take effect. Consider deleting one copy.
os.makedirs("./models", exist_ok = True)

def train(model, dataset, epochs, batch_size=1000, use_scheduler=False, oversample=0, eval_dataset=None, savemodelas='autosave.pt', snapshots_every=-1, vm=None):
    """Train `model` with Adam on mean absolute error, logging progress to TensorBoard.

    Args:
        model: nn.Module mapping (B, 2) points to (B, 1) predictions. Batches are moved to
            the device of its first parameter.
        dataset: map-style dataset yielding (input, target, index). When `oversample` is
            non-zero it must also provide start_oversample/add_oversample (as
            MandelbrotDataSet does).
        epochs: number of passes over `dataset`.
        batch_size: batch size for training and for `evaluate`.
        use_scheduler: halve the learning rate every 5 epochs (StepLR).
        oversample: fraction of each batch (its highest-loss samples) queued to be repeated
            in the next epoch; 0 disables oversampling.
        eval_dataset: optional dataset scored with `evaluate` at the start, at every
            snapshot, and at the end.
        savemodelas: file name under ./models for the state_dict (saved every epoch and
            at the end); None disables saving.
        snapshots_every: iterations between TensorBoard image/eval snapshots; -1 disables.
        vm: optional VideoMaker that captures a frame every `vm.capture_rate` iterations.

    Returns:
        List with the average training loss of each epoch.
    """
    print("Initializing...")
    tb = SummaryWriter()
    # `__file__` is not defined inside a notebook kernel; fall back to the launcher script.
    source_file = globals().get("__file__", sys.argv[0])
    logger = Logger(source_file, dir=tb.log_dir)
    # Don't copy the same file twice (the first copy is read-only).
    if os.path.abspath(sys.argv[0]) != os.path.abspath(source_file):
        logger.copyFile(sys.argv[0])

    logger.createDir('images')
    logger.createDir('models')

    model_device = next(model.parameters()).device
    optim = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    scheduler = None
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=5, gamma=0.5)

    per_batch = 0
    if oversample != 0:
        per_batch = math.floor(batch_size * oversample)
        dataset.start_oversample(math.floor(len(dataset)*oversample))

    print('Training...')
    avg_losses = []
    tot_iterations = 0
    if eval_dataset is not None:
        tb.add_scalar('Loss/eval', evaluate(model, eval_dataset, batch_size), tot_iterations)

    for epoch in range(epochs):
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loop = tqdm(total=len(loader), position=0)
        tot_loss = 0

        for i, (inputs, outputs, indices) in enumerate(loader):
            if vm is not None and tot_iterations % vm.capture_rate == 0:
                vm.generateFrame(model)
            inputs, outputs = inputs.to(model_device), outputs.to(model_device)

            optim.zero_grad()

            # Flatten both to (B,) so a (B, 1) target can never broadcast to (B, B).
            pred = model(inputs).reshape(-1).float()
            outputs = outputs.reshape(-1).float()

            all_losses = torch.abs(outputs - pred)

            if oversample != 0:
                # Queue this batch's highest-loss samples to be repeated next epoch.
                size = min(per_batch, len(all_losses))
                selected_indices = torch.topk(all_losses, size)[1].cpu()
                dataset.add_oversample(indices[selected_indices])

            loss = torch.mean(all_losses)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
            optim.step()
            tot_loss += loss.item()

            loop.set_description('epoch:{:d} Loss:{:.6f}'.format(epoch, tot_loss/(i+1)))
            loop.update(1)
            tb.add_scalar('Loss/Train', loss.item(), tot_iterations)
            # Read the LR from the optimizer so this also works without a scheduler.
            tb.add_scalar('Learning Rate', optim.param_groups[0]['lr'], tot_iterations)
            tot_iterations += 1

            if snapshots_every != -1 and tot_iterations % snapshots_every == 0:
                # renderModel returns a 2-D (height, width) numpy array.
                sample = renderModel(model, 960, 544, max_gpu=True)
                tb.add_image('sample', sample, tot_iterations, dataformats='HW')
                if eval_dataset is not None:
                    tb.add_scalar('Loss/eval', evaluate(model, eval_dataset, batch_size), tot_iterations)

        loop.close()
        avg_losses.append(tot_loss/len(loader) if len(loader) else float('nan'))

        if scheduler is not None:
            scheduler.step()
        if hasattr(dataset, 'update_oversample'):
            dataset.update_oversample()

        if savemodelas is not None:
            torch.save(model.state_dict(), './models/'+savemodelas)
    print("Finished Training.")
    print("Final learning rate: ", optim.param_groups[0]['lr'])
    if eval_dataset is not None:
        tb.add_scalar('Loss/eval', evaluate(model, eval_dataset, batch_size), tot_iterations)

    if vm is not None:
        print("Finalizing Capture...")
        vm.generateFrame(model)
        vm.generateVideo()
    if savemodelas is not None:
        print('Saving...')
        torch.save(model.state_dict(), './models/'+savemodelas)
    print("Done.")
    plt.show()
    tb.close()
    # Stop teeing stdout into this run's results file.
    if sys.stdout is logger:
        sys.stdout = logger.console
    return avg_losses

def evaluate(model, eval_dataset, batch_size):
    """Mean absolute error of `model` on `eval_dataset`, averaged over batches.

    Batches are (inputs, targets, indices) and are moved to the device of the model's
    first parameter. The model's train/eval mode is restored afterwards.

    Returns:
        The average of the per-batch MAE values, or nan for an empty dataset.
    """
    was_training = model.training
    model.eval()
    try:
        model_device = next(model.parameters(), torch.zeros(0)).device
        loader = DataLoader(eval_dataset, batch_size=batch_size)
        tot_loss = 0.0
        with torch.no_grad():
            for inputs, outputs, _ in loader:
                inputs, outputs = inputs.to(model_device), outputs.to(model_device)
                pred = model(inputs).reshape(-1).float()
                loss = torch.mean(torch.abs(outputs.reshape(-1).float() - pred))
                # Accumulate across batches (not just keep the last one).
                tot_loss += loss.item()
    finally:
        model.train(was_training)
    if len(loader) == 0:
        return float('nan')
    return tot_loss/len(loader)

In [None]:
def parse_args(argv=None):
    """Parse options for the Mandelbrot zoom video.

    Positional arguments are optional and default to the values `main` has always used.

    Args:
        argv: argument list to parse. None means sys.argv[1:], except inside a Jupyter
            (ipykernel) kernel, whose own argv is not meant for this parser; there the
            defaults are used.

    Returns:
        argparse.Namespace with resx, resy, frames, xmin, xmax, yoffset, max_depth,
        zoom_speed, video_name and gpu.
    """
    parser = argparse.ArgumentParser(description="Generate a Mandelbrot set zoom Video.")
    parser.add_argument("resx", type=int, nargs="?", default=3840, help="Width of the image.")
    parser.add_argument("resy", type=int, nargs="?", default=2160, help="Height of the image.")
    parser.add_argument("frames", type=int, nargs="?", default=100, help="Number of frames")
    parser.add_argument("xmin", type=float, nargs="?", default=-2.4, help="Minimum x value in the 2d space.")
    parser.add_argument("xmax", type=float, nargs="?", default=1.0, help="Maximum value in the 2d space.")
    parser.add_argument("--yoffset", type=float, default=0.5, help="Y offset." )
    parser.add_argument("--max_depth", type=int, default=500, help="Max depth param for mandelbrot functions.")
    parser.add_argument("--zoom_speed", type=float, default=0.05, help="The fraction by which to zoom in each frame.")
    parser.add_argument("--video_name", type=str, default="mandelbrot", help="The name of the video.")
    parser.add_argument("--gpu", action="store_true", help="Render frames with the GPU implementation.")
    if argv is None and "ipykernel" in sys.modules:
        argv = []
    return parser.parse_args(argv)

def main():
    """Render a zoom into the Mandelbrot set as PNG frames and encode them with ffmpeg.

    Settings come from parse_args(); any value it does not supply falls back to the
    defaults below. Each frame shrinks the x-range by `zoom_speed` around its centre.
    The y-range follows from the aspect ratio inside renderMandelbrot.
    """
    settings = {
        "frames": 100, "max_depth": 500, "zoom_speed": 0.05,
        "resx": 3840, "resy": 2160, "video_name": "mandelbrot",
        "xmin": -2.4, "xmax": 1, "yoffset": 0.5, "gpu": False,
    }
    args = parse_args()
    if args is not None:
        settings.update({k: v for k, v in vars(args).items() if v is not None})

    video_name = settings["video_name"]
    frames_dir = f'frames/{video_name}'
    os.makedirs(frames_dir, exist_ok=True)

    xmin, xmax = settings["xmin"], settings["xmax"]
    for i in tqdm(range(settings["frames"])):
        image = renderMandelbrot(resx=settings["resx"], resy=settings["resy"], xmin=xmin, xmax=xmax,
                                 yoffset=settings["yoffset"], max_depth=settings["max_depth"], gpu=settings["gpu"])
        plt.imsave(f'{frames_dir}/frame_{i:03d}.png', image, vmin=0, vmax=1, cmap='gist_heat')

        # Shrink the x-range symmetrically to zoom in on its centre.
        margin = settings["zoom_speed"] * (xmax - xmin) / 2
        xmin += margin
        xmax -= margin

    video_path = f'{frames_dir}/{video_name}.mp4'
    # Argument list instead of a shell string: names from the CLI cannot inject commands.
    # -y overwrites an existing video instead of blocking on ffmpeg's prompt.
    command = ['ffmpeg', '-y', '-framerate', '60', '-i', f'{frames_dir}/frame_%03d.png',
               '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '20', '-preset', 'slow', video_path]
    subprocess.run(command, check=True)

# Entry point. Inside a notebook kernel __name__ is "__main__", so executing this cell
# starts the full render immediately.
if __name__ == "__main__":
    main()

  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/3840 [00:00<?, ?it/s][A
  0%|          | 12/3840 [00:00<00:34, 111.01it/s][A
  1%|          | 24/3840 [00:00<00:33, 113.12it/s][A
  1%|          | 36/3840 [00:00<00:33, 112.55it/s][A
  1%|▏         | 48/3840 [00:00<00:33, 113.34it/s][A
  2%|▏         | 60/3840 [00:00<00:33, 114.24it/s][A
  2%|▏         | 72/3840 [00:00<00:32, 115.19it/s][A
  2%|▏         | 84/3840 [00:00<00:32, 115.31it/s][A
  2%|▎         | 96/3840 [00:00<00:32, 114.51it/s][A
  3%|▎         | 108/3840 [00:00<00:32, 114.74it/s][A
  3%|▎         | 120/3840 [00:01<00:32, 114.62it/s][A
  3%|▎         | 132/3840 [00:01<00:32, 114.92it/s][A
  4%|▍         | 144/3840 [00:01<00:32, 115.37it/s][A
  4%|▍         | 156/3840 [00:01<00:31, 116.07it/s][A
  4%|▍         | 168/3840 [00:01<00:31, 115.91it/s][A
  5%|▍         | 180/3840 [00:01<00:31, 115.47it/s][A
  5%|▌         | 192/3840 [00:01<00:31, 115.65it/s][A
  5%|▌         | 204/3840 [00:01<00:31, 116.

In [None]:
class ELM(nn.Module):
    """Extreme learning machine: a frozen random hidden layer plus a linear readout.

    The hidden weights are drawn from U(0, 1) (nn.init.uniform_ defaults) and frozen.
    The hidden bias keeps its default init and stays trainable. The readout has no bias
    and is fitted in closed form later in this notebook (pseudo-inverse).
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.hidden_layer = nn.Linear(input_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, output_dim, bias=False)
        self.activation = nn.SELU()

        nn.init.uniform_(self.hidden_layer.weight)
        self.hidden_layer.weight.requires_grad_(False)

    def forward(self, x):
        hidden = self.activation(self.hidden_layer(x))
        return self.output_layer(hidden)
    
dataset = ImageDataset('/kaggle/input/images/DatasetImages/blob.png')
# One pass over every pixel, transposed into (all coordinates, all intensities).
x, y = (torch.stack(column) for column in zip(*(dataset[i] for i in range(len(dataset)))))
y = y.unsqueeze(1)  # per-pixel scalars (N,) -> (N, 1), matching the ELM's single output
print(x.shape, y.shape)

model = ELM(2, 1000, 1)

def evaluate_model(model, x, y):
    """Mean squared error of model(x) against y, computed without autograd; returns a float."""
    with torch.no_grad():
        squared_error = (model(x) - y) ** 2
    return squared_error.mean().item()

print('Before training, MSE: {:f}'.format(evaluate_model(model, x, y)))

# Closed-form ELM fit: with hidden activations H (N, hidden_dim), choose readout weights W
# minimising ||H @ W - y|| via the Moore-Penrose pseudo-inverse.
# Done without autograd: the hidden bias still requires grad, so H would otherwise carry
# an unnecessary graph.
with torch.no_grad():
    H = model.activation(model.hidden_layer(x))
    print(H.shape)
    H_pinv = torch.linalg.pinv(H)
    print(H_pinv.shape, y.shape)
    output_weights = torch.mm(H_pinv, y)  # (hidden_dim, output_dim)
    # nn.Linear stores its weight as (out, in); transpose (not view) so this is also
    # correct when output_dim > 1.
    model.output_layer.weight.copy_(output_weights.T)

print('After Training, MSE: {:f}'.format(evaluate_model(model, x, y)))

resx, resy = dataset.width, dataset.height
# 'ij' meshgrid over [-1, 1]^2 gives (resx, resy, 2). rot90 turns it into (resy, resx, 2)
# with y increasing along rows, which matches origin='lower' below.
linspace = torch.stack(torch.meshgrid(torch.linspace(-1, 1, resx), torch.linspace(1, -1, resy), indexing='ij'), dim=-1)
# Keep the grid on the model's device (the ELM above was never moved off the CPU).
linspace = linspace.to(next(model.parameters()).device)
linspace = torch.rot90(linspace, 1, (0, 1))
plt.imshow(renderModel(model, resx=resx, resy=resy, linspace=linspace), cmap='magma', origin='lower')
plt.show()

In [None]:
def fourier2d(x, order):
    """Products of the first `order` harmonics of x[0] and x[1].

    For n (outer) and m (inner) in 1..order, emits
    [cos(n x0) cos(m x1), cos(n x0) sin(m x1), sin(n x0) cos(m x1), sin(n x0) sin(m x1)].
    Returns a flat tensor of length 4 * order**2.
    """
    chunks = []
    for n in range(1, order + 1):
        cos_nx = torch.cos(n * x[0])
        sin_nx = torch.sin(n * x[0])
        for m in range(1, order + 1):
            cos_my = torch.cos(m * x[1])
            sin_my = torch.sin(m * x[1])
            chunks.append(torch.tensor([cos_nx * cos_my, cos_nx * sin_my, sin_nx * cos_my, sin_nx * sin_my]))
    return torch.cat(chunks)

# Smoke test: order 10 yields 4 * 10 * 10 = 400 features for the integer point (1, 3).
print(fourier2d(torch.tensor([1,3]),10))

In [None]:
def example_render():
    """Render a 3840x2160 Mandelbrot image (max_depth 500) with the GPU path and save it as PNG."""
    width, height = 3840, 2160
    rendered = renderMandelbrot(width, height, max_depth=500, gpu=True)
    plt.imsave('./captures/images/mandel_gpu.png', rendered, vmin=0, vmax=1, cmap='gist_heat')

def example_train():
    """Train SkipConn(300, 50) on 2M GPU-sampled points, evaluating on 100k more."""
    print("Initializing Model...")
    # No `models` module is imported in this notebook; use the classes defined above.
    model = SkipConn(300,50).to(device)

    dataset = MandelbrotDataSet(2000000, gpu=True)
    eval_dataset = MandelbrotDataSet(100000, gpu=True)

    train(model, dataset, 10, batch_size=10000, eval_dataset=eval_dataset, oversample=0.1, use_scheduler=True, snapshots_every=50)

def example_render_model():
    """Load the Jun04_00-34-51 Fourier checkpoint and render it at 7680x4320 to a PNG."""
    # No `models` module is imported in this notebook; use the classes defined above.
    linmap = CenteredLinearMap(x_size=torch.pi*2, y_size=torch.pi*2)
    model = Fourier(256, 400, 50, linmap=linmap)
    # map_location='cpu' lets the checkpoint load regardless of the device it was saved
    # from; the model is moved to the GPU right after.
    model.load_state_dict(torch.load('./models/Jun04_00-34-51_xerxes-u.pt', map_location='cpu'))
    model.cuda()
    image = renderModel(model, 7680, 4320, max_gpu=False)
    plt.imsave('./captures/images/Jun04_00-34-51_xerxes-u.png', image, vmin=0, vmax=1, cmap='inferno')
    plt.show()

def example_train_capture():
    """Train a Fourier model for one epoch while recording frames with a VideoMaker.

    At frame 5 the view switches to the full set (capture every 8 iterations). At frame
    10 it zooms to x in [-1.8, -0.9] (capture every 16 iterations).
    """
    shots = [
        {'frame':5, "xmin":-2.5, "xmax":1, "yoffset":0, "capture_rate":8},
        {'frame':10, "xmin":-1.8, "xmax":-0.9, "yoffset":0.2, "capture_rate":16},
    ]

    vidmaker = VideoMaker('test', dims=(960,544), capture_rate=5, shots=shots, max_gpu=True)
    # No `models` module is imported in this notebook; use the classes defined above.
    linmap = CenteredLinearMap(x_size=torch.pi*2, y_size=torch.pi*2)
    # The VideoMaker renders on its (GPU) grid, so the model must live on the same device.
    model = Fourier(256,400,50, linmap=linmap).to(device)
    dataset = MandelbrotDataSet(2000000, max_depth=1500, gpu=True)
    train(model, dataset, 1, batch_size=8000, use_scheduler=True, oversample=0.1, snapshots_every=500, vm=vidmaker)

def create_dataset():
    """Generate 100000 GPU-sampled points (max_depth 50) and save them as ./data/1M_50_test_*.pt."""
    # NOTE(review): the file tag says "1M" but only 100000 points are generated; confirm which is intended.
    MandelbrotDataSet(100000, max_depth=50, gpu=True).save('1M_50_test')

# Entry point: executing this cell (or the file as a script) immediately runs the
# training-with-video-capture example.
if __name__ == "__main__":
    example_train_capture()