In [1]:
import ipywebgl
import numpy as np

In [2]:
# Create the WebGL viewer widget and queue its initial GL state:
# - a light-grey clear color and one clear;
# - depth testing disabled (every pass in this notebook draws a full-screen quad).
# NOTE(review): execute_once=True presumably runs the queued commands a single
# time instead of on every frame; confirm against the ipywebgl docs.
w = ipywebgl.GLViewer()
w.clear_color(.8, .8, .8 ,1)
w.clear()
w.disable(depth_test=True)
w.execute_commands(execute_once=True)

The atmosphere shaders below are adapted from the Shadertoy at https://www.shadertoy.com/view/slSXRW.

In [3]:
# Full-screen quad in clip space: two triangles covering [-1, 1] x [-1, 1],
# one (x, y) row per vertex, flattened into a single float32 array.
quad_vertices = np.array(
    [
        [-1,  1],
        [-1, -1],
        [ 1, -1],  # end of first triangle
        [-1,  1],
        [ 1, -1],
        [ 1,  1],  # end of second triangle
    ],
    dtype=np.float32,
).flatten()

screen_vbo = w.create_buffer_ext(src_data=quad_vertices)

# One vertex attribute at location 0 read from screen_vbo ('2f32' presumably
# means two float32 components per vertex). The shader programs below bind
# `in_vert` to location 0.
screen_vao = w.create_vertex_array_ext(
    None,
    [(screen_vbo, '2f32', 0)],
)

In [4]:
# GLSL ES 3.00 prelude that is prepended to every atmosphere fragment shader
# in this notebook. It contains:
# - the shared physical constants (distances in megameters);
# - the sun animation;
# - the Mie/Rayleigh phase functions;
# - the scattering/extinction model;
# - a ray-sphere intersection helper;
# - the LUT lookup helpers.
# Because it begins with the `#version` directive, it must come first in each
# shader source string.
common_pixel_shader = """#version 300 es
precision highp float;

const float PI = 3.14159265358;

// Units are in megameters.
const float groundRadiusMM = 6.360;
const float atmosphereRadiusMM = 6.460;

// View position: 200 m above the ground.
const vec3 viewPos = vec3(0.0, groundRadiusMM + 0.0002, 0.0);

// LUT resolutions, in pixels.
const vec2 tLUTRes = vec2(256.0, 64.0);
const vec2 msLUTRes = vec2(32.0, 32.0);
// Doubled the vertical skyLUT res from the paper; it looks much
// better at sunrise.
const vec2 skyLUTRes = vec2(200.0, 200.0);

const vec3 groundAlbedo = vec3(0.3);

// Scattering / absorption coefficients, per megameter.
const vec3 rayleighScatteringBase = vec3(5.802, 13.558, 33.1);
const float rayleighAbsorptionBase = 0.0;

const float mieScatteringBase = 3.996;
const float mieAbsorptionBase = 4.4;

const vec3 ozoneAbsorptionBase = vec3(0.650, 1.881, .085);

/*
 * Animates the sun. The altitude follows a triangle wave with a period of
 * periodSec (time presumably in seconds). It rises from slightly below the
 * horizon (-sunriseShift * PI/2) to the zenith (PI/2) and back.
 * Returns the altitude in radians.
 */
float getSunAltitude(float time)
{
    const float periodSec = 120.0;
    const float halfPeriod = periodSec / 2.0;
    const float sunriseShift = 0.1;
    float cyclePoint = (1.0 - abs((mod(time,periodSec)-halfPeriod)/halfPeriod));
    cyclePoint = (cyclePoint*(1.0+sunriseShift))-sunriseShift;
    return (0.5*PI)*cyclePoint;
}

// Unit direction towards the sun. It moves in the y-z plane.
vec3 getSunDir(float time)
{
    float altitude = getSunAltitude(time);
    return normalize(vec3(0.0, sin(altitude), -cos(altitude)));
}

// Cornette-Shanks Mie phase function with asymmetry parameter g = 0.8.
float getMiePhase(float cosTheta) {
    const float g = 0.8;
    const float scale = 3.0/(8.0*PI);

    float num = (1.0-g*g)*(1.0+cosTheta*cosTheta);
    float denom = (2.0+g*g)*pow((1.0 + g*g - 2.0*g*cosTheta), 1.5);

    return scale*num/denom;
}

// Rayleigh phase function.
float getRayleighPhase(float cosTheta) {
    const float k = 3.0/(16.0*PI);
    return k*(1.0+cosTheta*cosTheta);
}

/*
 * Evaluates the participating medium at pos, a point relative to the
 * planet center. The outputs are:
 * - the Rayleigh scattering coefficient;
 * - the Mie scattering coefficient;
 * - the total extinction (scattering + absorption, including a tent-shaped
 *   ozone layer peaking at 25 km).
 */
void getScatteringValues(vec3 pos,
                         out vec3 rayleighScattering,
                         out float mieScattering,
                         out vec3 extinction) {
    float altitudeKM = (length(pos)-groundRadiusMM)*1000.0;
    // Note: Paper gets these switched up.
    float rayleighDensity = exp(-altitudeKM/8.0);
    float mieDensity = exp(-altitudeKM/1.2);

    rayleighScattering = rayleighScatteringBase*rayleighDensity;
    float rayleighAbsorption = rayleighAbsorptionBase*rayleighDensity;

    mieScattering = mieScatteringBase*mieDensity;
    float mieAbsorption = mieAbsorptionBase*mieDensity;

    vec3 ozoneAbsorption = ozoneAbsorptionBase*max(0.0, 1.0 - abs(altitudeKM-25.0)/15.0);

    extinction = rayleighScattering + rayleighAbsorption + mieScattering + mieAbsorption + ozoneAbsorption;
}

// acos with its argument clamped to [-1, 1], so rounding cannot produce NaN.
float safeacos(const float x) {
    return acos(clamp(x, -1.0, 1.0));
}

// From https://gamedev.stackexchange.com/questions/96459/fast-ray-sphere-collision-code.
// Returns the distance along rd (assumed unit length) from ro to a sphere of
// radius rad centered at the origin. Returns -1.0 when the ray misses it
// (or the sphere lies behind ro).
float rayIntersectSphere(vec3 ro, vec3 rd, float rad) {
    float b = dot(ro, rd);
    float c = dot(ro, ro) - rad*rad;
    // Outside the sphere and pointing away from it.
    if (c > 0.0 && b > 0.0) return -1.0;
    float discr = b*b - c;
    if (discr < 0.0) return -1.0;
    // Special case: inside sphere, use far discriminant
    if (discr > b*b) return (-b + sqrt(discr));
    return -b - sqrt(discr);
}

/*
 * Shared lookup for LUTs parameterized by (sun cos-zenith angle, height).
 * - x: cos of the sun zenith angle at pos, remapped from [-1, 1] to [0, 1].
 * - y: height of pos between the ground and the top of the atmosphere,
 *   clamped to [0, 1].
 * lutRes is the LUT size in pixels. bufferRes is the size of the texture
 * holding it, so a LUT occupying the corner of a larger buffer is still
 * addressed correctly. This is the inverse of the mapping the transmittance
 * pass uses to generate its LUT.
 */
vec3 sampleSunHeightLUT(sampler2D tex, vec2 lutRes, vec2 bufferRes, vec3 pos, vec3 sunDir) {
    float height = length(pos);
    vec3 up = pos / height;
    float sunCosZenithAngle = dot(sunDir, up);
    vec2 uv = vec2(lutRes.x*clamp(0.5 + 0.5*sunCosZenithAngle, 0.0, 1.0),
                   lutRes.y*max(0.0, min(1.0, (height - groundRadiusMM)/(atmosphereRadiusMM - groundRadiusMM))));
    uv /= bufferRes;
    return texture(tex, uv).rgb;
}

// Transmittance LUT lookup (tLUTRes-sized LUT stored in a bufferRes texture).
vec3 getValFromTLUT(sampler2D tex, vec2 bufferRes, vec3 pos, vec3 sunDir) {
    return sampleSunHeightLUT(tex, tLUTRes, bufferRes, pos, sunDir);
}

// Multiple-scattering LUT lookup (msLUTRes-sized LUT stored in a bufferRes texture).
vec3 getValFromMultiScattLUT(sampler2D tex, vec2 bufferRes, vec3 pos, vec3 sunDir) {
    return sampleSunHeightLUT(tex, msLUTRes, bufferRes, pos, sunDir);
}
"""

In [11]:
# Transmittance LUT pass: a full-screen quad whose fragment shader writes, for
# each texel, the transmittance from a (height, sun zenith angle) sample point
# to the sun. The shared `common_pixel_shader` prelude goes first because it
# carries the `#version` directive, constants and helpers.
transmittance_vertex_src = """#version 300 es

in vec2 in_vert;

void main() {
    gl_Position = vec4(in_vert, 0, 1);
}
"""

transmittance_fragment_src = common_pixel_shader + """
// Buffer A generates the Transmittance LUT. Each pixel coordinate corresponds to a height and sun zenith angle, and
// the value is the transmittance from that point to sun, through the atmosphere.
const float sunTransmittanceSteps = 40.0;

vec3 getSunTransmittance(vec3 pos, vec3 sunDir) {
    if (rayIntersectSphere(pos, sunDir, groundRadiusMM) > 0.0) {
        return vec3(0.0);
    }
    
    float atmoDist = rayIntersectSphere(pos, sunDir, atmosphereRadiusMM);
    float t = 0.0;
    
    vec3 transmittance = vec3(1.0);
    for (float i = 0.0; i < sunTransmittanceSteps; i += 1.0) {
        float newT = ((i + 0.3)/sunTransmittanceSteps)*atmoDist;
        float dt = newT - t;
        t = newT;
        
        vec3 newPos = pos + t*sunDir;
        
        vec3 rayleighScattering, extinction;
        float mieScattering;
        getScatteringValues(newPos, rayleighScattering, mieScattering, extinction);
        
        transmittance *= exp(-dt*extinction);
    }
    return transmittance;
}

out vec4 fragColor;

void main()
{
    float u = gl_FragCoord.x/tLUTRes.x;
    float v = gl_FragCoord.y/tLUTRes.y;
    
    float sunCosTheta = 2.0*u - 1.0;
    float sunTheta = safeacos(sunCosTheta);
    float height = mix(groundRadiusMM, atmosphereRadiusMM, v);
    
    vec3 pos = vec3(0.0, height, 0.0); 
    vec3 sunDir = normalize(vec3(0.0, sunCosTheta, -sin(sunTheta)));
    
    fragColor = vec4(getSunTransmittance(pos, sunDir), 1.0);
}
"""

# Bind the quad's vertex attribute (location 0, see screen_vao) to `in_vert`.
transmittance_attribute_locations = {'in_vert': 0}

transmittance_prog = w.create_program_ext(
    transmittance_vertex_src,
    transmittance_fragment_src,
    transmittance_attribute_locations,
)

In [6]:
# Allocate the 256x64 transmittance LUT texture (matching tLUTRes in the
# shader prelude). Attach it as the color target of an off-screen framebuffer.
# The texture is left bound to texture unit 0 when this cell finishes.
transmittance_buffer = w.create_framebuffer()
w.bind_framebuffer('FRAMEBUFFER', transmittance_buffer)

transmittance_lut = w.create_texture()
w.active_texture(0)
w.bind_texture('TEXTURE_2D', transmittance_lut)
# Nearest filtering and edge clamping.
# NOTE(review): the shader prelude reads LUTs through texture() at fractional
# uv; LINEAR filtering may be wanted once this LUT is sampled there.
w.tex_parameter('TEXTURE_2D', 'TEXTURE_MAG_FILTER', 'NEAREST')
w.tex_parameter('TEXTURE_2D', 'TEXTURE_MIN_FILTER', 'NEAREST')
w.tex_parameter('TEXTURE_2D', 'TEXTURE_WRAP_S', 'CLAMP_TO_EDGE')
w.tex_parameter('TEXTURE_2D', 'TEXTURE_WRAP_T', 'CLAMP_TO_EDGE')
# Immutable storage: 1 mip level, half-float RGBA, 256x64.
# NOTE(review): in WebGL2, RGBA16F is only color-renderable when the
# EXT_color_buffer_float extension is enabled. Confirm that ipywebgl enables
# it; otherwise this framebuffer is incomplete and the LUT pass draws nothing.
w.tex_storage_2d('TEXTURE_2D', 1, 'RGBA16F', 256, 64)
w.framebuffer_texture_2d('FRAMEBUFFER', 'COLOR_ATTACHMENT0', 'TEXTURE_2D', transmittance_lut, 0)

# Restore the default framebuffer and submit the setup commands.
# (execute_once=True presumably runs them a single time rather than per frame.)
w.bind_framebuffer('FRAMEBUFFER', None)
w.execute_commands(execute_once=True)

In [12]:
# Render the transmittance LUT: draw the full-screen quad with
# transmittance_prog into the off-screen framebuffer. Then restore the default
# framebuffer and a viewport covering the whole viewer.
w.bind_framebuffer('FRAMEBUFFER', transmittance_buffer)
w.viewport(0, 0, 256, 64)  # must match the LUT texture size (and tLUTRes)
w.clear()
w.use_program(transmittance_prog)
w.bind_vertex_array(screen_vao)
w.draw_arrays('TRIANGLES',0, 6)  # 6 vertices = the two quad triangles
w.bind_framebuffer('FRAMEBUFFER', None)
w.viewport(0, 0, w.width, w.height)
w.execute_commands(execute_once=True)

In [8]:
# Debug program: draws `u_texture` texel-for-texel onto the full-screen quad.
# gl_FragCoord is wrapped with `% size`, so a texture smaller than the
# viewport (such as the 256x64 transmittance LUT) is tiled across the screen.
show_texture_vertex_src = """#version 300 es

in vec2 in_vert;

void main() {
    gl_Position = vec4(in_vert, 0, 1);
}
"""

show_texture_fragment_src = """#version 300 es
precision highp float;

uniform sampler2D u_texture;

out vec4 color;

void main() {
    ivec2 size = textureSize(u_texture, 0); //so we can display smaller textures
    color = vec4(texelFetch(u_texture, ivec2(gl_FragCoord.xy) % size, 0).rgb, 1.0);
    //color = vec4(gl_FragCoord.xyz,1);
}
"""

show_texture_prog = w.create_program_ext(
    show_texture_vertex_src,
    show_texture_fragment_src,
    {'in_vert': 0},  # quad vertex attribute location (see screen_vao)
)

In [13]:
# Display the transmittance LUT on the default framebuffer.
# Bind the LUT to texture unit 0 explicitly rather than relying on the binding
# left over from the framebuffer-setup cell. That way this cell keeps working
# after other cells (or re-runs) change the texture bound to unit 0.
w.bind_framebuffer('FRAMEBUFFER', None)
w.disable(depth_test=True)
w.clear()
w.active_texture(0)
w.bind_texture('TEXTURE_2D', transmittance_lut)
w.use_program(show_texture_prog)
w.bind_vertex_array(screen_vao)
w.uniform('u_texture', np.array([0], dtype=np.int32))  # sampler -> unit 0
w.draw_arrays('TRIANGLES', 0, 6)

w.execute_commands()
w

GLViewer(camera_pos=[0, 50, 200])